| @@ -47,12 +47,16 @@ const DatasetConfig: FC = () => { | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| currentProvider: currentRerankProvider, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const onRemove = (id: string) => { | |||
| const filteredDataSets = dataSet.filter(item => item.id !== id) | |||
| setDataSet(filteredDataSets) | |||
| const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) | |||
| const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| setDatasetConfigs({ | |||
| ...(datasetConfigs as any), | |||
| ...retrievalConfig, | |||
| @@ -172,7 +172,7 @@ const ConfigContent: FC<Props> = ({ | |||
| return false | |||
| return datasetConfigs.reranking_enable | |||
| }, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) | |||
| }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid]) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentRerankModel && !showRerankModel) | |||
| @@ -43,6 +43,7 @@ const ParamsConfig = ({ | |||
| const { | |||
| defaultModel: rerankDefaultModel, | |||
| currentModel: isRerankDefaultModelValid, | |||
| currentProvider: rerankDefaultProvider, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const isValid = () => { | |||
| @@ -91,7 +92,10 @@ const ParamsConfig = ({ | |||
| reranking_mode: restConfigs.reranking_mode, | |||
| weights: restConfigs.weights, | |||
| reranking_enable: restConfigs.reranking_enable, | |||
| }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) | |||
| }, selectedDatasets, selectedDatasets, { | |||
| provider: rerankDefaultProvider?.provider, | |||
| model: isRerankDefaultModelValid?.model, | |||
| }) | |||
| setTempDataSetConfigs({ | |||
| ...retrievalConfig, | |||
| @@ -226,6 +226,7 @@ const Configuration: FC = () => { | |||
| const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| currentProvider: currentRerankProvider, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const handleSelect = (data: DataSet[]) => { | |||
| if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { | |||
| @@ -279,7 +280,10 @@ const Configuration: FC = () => { | |||
| reranking_mode: restConfigs.reranking_mode, | |||
| weights: restConfigs.weights, | |||
| reranking_enable: restConfigs.reranking_enable, | |||
| }, newDatasets, dataSets, !!currentRerankModel) | |||
| }, newDatasets, dataSets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| setDatasetConfigs({ | |||
| ...retrievalConfig, | |||
| @@ -620,7 +624,10 @@ const Configuration: FC = () => { | |||
| syncToPublishedConfig(config) | |||
| setPublishedConfig(config) | |||
| const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) | |||
| const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| setDatasetConfigs({ | |||
| retrieval_model: RETRIEVE_TYPE.multiWay, | |||
| ...modelConfig.dataset_configs, | |||
| @@ -1,7 +1,7 @@ | |||
| import { BlockEnum } from '../../types' | |||
| import type { NodeDefault } from '../../types' | |||
| import type { KnowledgeRetrievalNodeType } from './types' | |||
| import { RerankingModeEnum } from '@/models/datasets' | |||
| import { checkoutRerankModelConfigedInRetrievalSettings } from './utils' | |||
| import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' | |||
| import { DATASET_DEFAULT } from '@/config' | |||
| import { RETRIEVE_TYPE } from '@/types/app' | |||
| @@ -36,12 +36,17 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = { | |||
| if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0)) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) }) | |||
| if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) | |||
| if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') }) | |||
| const { _datasets, multiple_retrieval_config, retrieval_mode } = payload | |||
| if (retrieval_mode === RETRIEVE_TYPE.multiWay) { | |||
| const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config) | |||
| if (!errorMessages && !checked) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) | |||
| } | |||
| return { | |||
| isValid: !errorMessages, | |||
| errorMessage: errorMessages, | |||
| @@ -1,6 +1,7 @@ | |||
| import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types' | |||
| import type { RETRIEVE_TYPE } from '@/types/app' | |||
| import type { | |||
| DataSet, | |||
| RerankingModeEnum, | |||
| } from '@/models/datasets' | |||
| @@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & { | |||
| retrieval_mode: RETRIEVE_TYPE | |||
| multiple_retrieval_config?: MultipleRetrievalConfig | |||
| single_retrieval_config?: SingleRetrievalConfig | |||
| _datasets?: DataSet[] | |||
| } | |||
| @@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| currentProvider: currentRerankProvider, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| @@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| draft.retrieval_mode = newMode | |||
| if (newMode === RETRIEVE_TYPE.multiWay) { | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| } | |||
| else { | |||
| const hasSetModel = draft.single_retrieval_config?.model?.provider | |||
| @@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) | |||
| }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) | |||
| const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs, selectedDatasets, currentRerankModel]) | |||
| }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) | |||
| // datasets | |||
| useEffect(() => { | |||
| @@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.dataset_ids = datasetIds | |||
| draft._datasets = selectedDatasets | |||
| }) | |||
| setInputs(newInputs) | |||
| })() | |||
| @@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } = getSelectedDatasetsMode(newDatasets) | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.dataset_ids = newDatasets.map(d => d.id) | |||
| draft._datasets = newDatasets | |||
| if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| } | |||
| }) | |||
| setInputs(newInputs) | |||
| @@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| || allExternal | |||
| ) | |||
| setRerankModelOpen(true) | |||
| }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) | |||
| }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider]) | |||
| const filterVar = useCallback((varPayload: Var) => { | |||
| return varPayload.type === VarType.string | |||
| @@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = ( | |||
| multipleRetrievalConfig: MultipleRetrievalConfig, | |||
| selectedDatasets: DataSet[], | |||
| originalDatasets: DataSet[], | |||
| isValidRerankModel?: boolean, | |||
| validRerankModel?: { provider?: string; model?: string }, | |||
| ) => { | |||
| const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||
| const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model | |||
| const { | |||
| allHighQuality, | |||
| @@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = ( | |||
| reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true, | |||
| } | |||
| if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal) | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { | |||
| if (!isValidRerankModel) | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| else | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| if (!rerankModelIsValid) | |||
| result.reranking_model = undefined | |||
| const setDefaultWeights = () => { | |||
| result.weights = { | |||
| vector_setting: { | |||
| vector_weight: allHighQualityVectorSearch | |||
| @@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = ( | |||
| } | |||
| } | |||
| if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { | |||
| if (!isValidRerankModel) | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| else | |||
| if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| if (rerankModelIsValid) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.reranking_model = { | |||
| provider: validRerankModel?.provider || '', | |||
| model: validRerankModel?.model || '', | |||
| } | |||
| } | |||
| else { | |||
| result.reranking_model = undefined | |||
| } | |||
| } | |||
| result.weights = { | |||
| vector_setting: { | |||
| vector_weight: allHighQualityVectorSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic | |||
| : allHighQualityFullTextSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic | |||
| : DEFAULT_WEIGHTED_SCORE.other.semantic, | |||
| embedding_provider_name: selectedDatasets[0].embedding_model_provider, | |||
| embedding_model_name: selectedDatasets[0].embedding_model, | |||
| }, | |||
| keyword_setting: { | |||
| keyword_weight: allHighQualityVectorSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword | |||
| : allHighQualityFullTextSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword | |||
| : DEFAULT_WEIGHTED_SCORE.other.keyword, | |||
| }, | |||
| if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { | |||
| if (!reranking_mode) { | |||
| if (validRerankModel?.provider && validRerankModel?.model) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.reranking_model = { | |||
| provider: validRerankModel.provider, | |||
| model: validRerankModel.model, | |||
| } | |||
| } | |||
| else { | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| setDefaultWeights() | |||
| } | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.WeightedScore && !weights) | |||
| setDefaultWeights() | |||
| if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) { | |||
| if (rerankModelIsValid) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.reranking_model = { | |||
| provider: validRerankModel.provider || '', | |||
| model: validRerankModel.model || '', | |||
| } | |||
| } | |||
| else { | |||
| setDefaultWeights() | |||
| } | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| setDefaultWeights() | |||
| } | |||
| } | |||
| return result | |||
| } | |||
| export const checkoutRerankModelConfigedInRetrievalSettings = ( | |||
| datasets: DataSet[], | |||
| multipleRetrievalConfig?: MultipleRetrievalConfig, | |||
| ) => { | |||
| if (!multipleRetrievalConfig) | |||
| return true | |||
| const { | |||
| allEconomic, | |||
| allExternal, | |||
| } = getSelectedDatasetsMode(datasets) | |||
| const { | |||
| reranking_enable, | |||
| reranking_mode, | |||
| reranking_model, | |||
| } = multipleRetrievalConfig | |||
| if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { | |||
| if ((allEconomic || allExternal) && !reranking_enable) | |||
| return true | |||
| return false | |||
| } | |||
| return true | |||
| } | |||