| import ConfigContext from '@/context/debug-configuration' | import ConfigContext from '@/context/debug-configuration' | ||||
| import { AppType } from '@/types/app' | import { AppType } from '@/types/app' | ||||
| import type { DataSet } from '@/models/datasets' | import type { DataSet } from '@/models/datasets' | ||||
| import { | |||||
| getMultipleRetrievalConfig, | |||||
| } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' | |||||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| const Icon = ( | const Icon = ( | ||||
| <svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> | <svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> | ||||
| setModelConfig, | setModelConfig, | ||||
| showSelectDataSet, | showSelectDataSet, | ||||
| isAgent, | isAgent, | ||||
| datasetConfigs, | |||||
| setDatasetConfigs, | |||||
| } = useContext(ConfigContext) | } = useContext(ConfigContext) | ||||
| const formattingChangedDispatcher = useFormattingChangedDispatcher() | const formattingChangedDispatcher = useFormattingChangedDispatcher() | ||||
| const hasData = dataSet.length > 0 | const hasData = dataSet.length > 0 | ||||
| const { | |||||
| currentModel: currentRerankModel, | |||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||||
| const onRemove = (id: string) => { | const onRemove = (id: string) => { | ||||
| setDataSet(dataSet.filter(item => item.id !== id)) | |||||
| const filteredDataSets = dataSet.filter(item => item.id !== id) | |||||
| setDataSet(filteredDataSets) | |||||
| const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) | |||||
| setDatasetConfigs({ | |||||
| ...(datasetConfigs as any), | |||||
| ...retrievalConfig, | |||||
| }) | |||||
| formattingChangedDispatcher() | formattingChangedDispatcher() | ||||
| } | } | ||||
| retrieval_model: RETRIEVE_TYPE.multiWay, | retrieval_model: RETRIEVE_TYPE.multiWay, | ||||
| }, isInWorkflow) | }, isInWorkflow) | ||||
| } | } | ||||
| }, [type]) | |||||
| }, [type, datasetConfigs, isInWorkflow, onChange]) | |||||
| const { | const { | ||||
| modelList: rerankModelList, | modelList: rerankModelList, | 
| import type { DatasetConfigs } from '@/models/debug' | import type { DatasetConfigs } from '@/models/debug' | ||||
| import { | import { | ||||
| getMultipleRetrievalConfig, | getMultipleRetrievalConfig, | ||||
| getSelectedDatasetsMode, | |||||
| } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' | } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' | ||||
| type ParamsConfigProps = { | type ParamsConfigProps = { | ||||
| const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs) | const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs) | ||||
| useEffect(() => { | useEffect(() => { | ||||
| const { | |||||
| allEconomic, | |||||
| allHighQuality, | |||||
| allHighQualityFullTextSearch, | |||||
| allHighQualityVectorSearch, | |||||
| allExternal, | |||||
| mixtureHighQualityAndEconomic, | |||||
| inconsistentEmbeddingModel, | |||||
| mixtureInternalAndExternal, | |||||
| } = getSelectedDatasetsMode(selectedDatasets) | |||||
| if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) | |||||
| setRerankSettingModalOpen(false) | |||||
| if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1)) | |||||
| setRerankSettingModalOpen(true) | |||||
| }, [selectedDatasets]) | |||||
| useEffect(() => { | |||||
| const { | |||||
| allEconomic, | |||||
| allInternal, | |||||
| allExternal, | |||||
| } = getSelectedDatasetsMode(selectedDatasets) | |||||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | |||||
| let rerankEnable = restConfigs.reranking_enable | |||||
| if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) | |||||
| rerankEnable = false | |||||
| setTempDataSetConfigs({ | |||||
| ...getMultipleRetrievalConfig({ | |||||
| top_k: restConfigs.top_k, | |||||
| score_threshold: restConfigs.score_threshold, | |||||
| reranking_model: restConfigs.reranking_model && { | |||||
| provider: restConfigs.reranking_model.reranking_provider_name, | |||||
| model: restConfigs.reranking_model.reranking_model_name, | |||||
| }, | |||||
| reranking_mode: restConfigs.reranking_mode, | |||||
| weights: restConfigs.weights, | |||||
| reranking_enable: rerankEnable, | |||||
| }, selectedDatasets), | |||||
| reranking_model: restConfigs.reranking_model && { | |||||
| reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, | |||||
| reranking_model_name: restConfigs.reranking_model.reranking_model_name, | |||||
| }, | |||||
| retrieval_model, | |||||
| score_threshold_enabled, | |||||
| datasets, | |||||
| }) | |||||
| }, [selectedDatasets, datasetConfigs]) | |||||
| setTempDataSetConfigs(datasetConfigs) | |||||
| }, [datasetConfigs]) | |||||
| const { | const { | ||||
| defaultModel: rerankDefaultModel, | defaultModel: rerankDefaultModel, | ||||
| reranking_mode: restConfigs.reranking_mode, | reranking_mode: restConfigs.reranking_mode, | ||||
| weights: restConfigs.weights, | weights: restConfigs.weights, | ||||
| reranking_enable: restConfigs.reranking_enable, | reranking_enable: restConfigs.reranking_enable, | ||||
| }, selectedDatasets) | |||||
| }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) | |||||
| setTempDataSetConfigs({ | setTempDataSetConfigs({ | ||||
| ...retrievalConfig, | ...retrievalConfig, | ||||
| <div className='mt-6 flex justify-end'> | <div className='mt-6 flex justify-end'> | ||||
| <Button className='mr-2 flex-shrink-0' onClick={() => { | <Button className='mr-2 flex-shrink-0' onClick={() => { | ||||
| setTempDataSetConfigs(datasetConfigs) | |||||
| setRerankSettingModalOpen(false) | setRerankSettingModalOpen(false) | ||||
| }}>{t('common.operation.cancel')}</Button> | }}>{t('common.operation.cancel')}</Button> | ||||
| <Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button> | <Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button> | 
| import Config from '@/app/components/app/configuration/config' | import Config from '@/app/components/app/configuration/config' | ||||
| import Debug from '@/app/components/app/configuration/debug' | import Debug from '@/app/components/app/configuration/debug' | ||||
| import Confirm from '@/app/components/base/confirm' | import Confirm from '@/app/components/base/confirm' | ||||
| import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import { ToastContext } from '@/app/components/base/toast' | import { ToastContext } from '@/app/components/base/toast' | ||||
| import { fetchAppDetail, updateAppModelConfig } from '@/service/apps' | import { fetchAppDetail, updateAppModelConfig } from '@/service/apps' | ||||
| import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config' | import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config' | ||||
| import Drawer from '@/app/components/base/drawer' | import Drawer from '@/app/components/base/drawer' | ||||
| import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' | import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' | ||||
| import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' | import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' | ||||
| import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { | |||||
| useModelListAndDefaultModelAndCurrentProviderAndModel, | |||||
| useTextGenerationCurrentProviderAndModelAndModelList, | |||||
| } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { fetchCollectionList } from '@/service/tools' | import { fetchCollectionList } from '@/service/tools' | ||||
| import { type Collection } from '@/app/components/tools/types' | import { type Collection } from '@/app/components/tools/types' | ||||
| import { useStore as useAppStore } from '@/app/components/app/store' | import { useStore as useAppStore } from '@/app/components/app/store' | ||||
| const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false) | const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false) | ||||
| const selectedIds = dataSets.map(item => item.id) | const selectedIds = dataSets.map(item => item.id) | ||||
| const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) | const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) | ||||
| const { | |||||
| currentModel: currentRerankModel, | |||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||||
| const handleSelect = (data: DataSet[]) => { | const handleSelect = (data: DataSet[]) => { | ||||
| if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { | if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { | ||||
| hideSelectDataSet() | hideSelectDataSet() | ||||
| reranking_mode: restConfigs.reranking_mode, | reranking_mode: restConfigs.reranking_mode, | ||||
| weights: restConfigs.weights, | weights: restConfigs.weights, | ||||
| reranking_enable: restConfigs.reranking_enable, | reranking_enable: restConfigs.reranking_enable, | ||||
| }, newDatasets) | |||||
| }, newDatasets, dataSets, !!currentRerankModel) | |||||
| setDatasetConfigs({ | setDatasetConfigs({ | ||||
| ...retrievalConfig, | ...retrievalConfig, | ||||
| syncToPublishedConfig(config) | syncToPublishedConfig(config) | ||||
| setPublishedConfig(config) | setPublishedConfig(config) | ||||
| const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) | |||||
| setDatasetConfigs({ | setDatasetConfigs({ | ||||
| retrieval_model: RETRIEVE_TYPE.multiWay, | retrieval_model: RETRIEVE_TYPE.multiWay, | ||||
| ...modelConfig.dataset_configs, | ...modelConfig.dataset_configs, | ||||
| ...retrievalConfig, | |||||
| }) | }) | ||||
| setHasFetchedDetail(true) | setHasFetchedDetail(true) | ||||
| }) | }) | 
| draft.retrieval_mode = newMode | draft.retrieval_mode = newMode | ||||
| if (newMode === RETRIEVE_TYPE.multiWay) { | if (newMode === RETRIEVE_TYPE.multiWay) { | ||||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | const multipleRetrievalConfig = draft.multiple_retrieval_config | ||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets) | |||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) | |||||
| } | } | ||||
| else { | else { | ||||
| const hasSetModel = draft.single_retrieval_config?.model?.provider | const hasSetModel = draft.single_retrieval_config?.model?.provider | ||||
| } | } | ||||
| }) | }) | ||||
| setInputs(newInputs) | setInputs(newInputs) | ||||
| }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets]) | |||||
| }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) | |||||
| const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | ||||
| const newInputs = produce(inputs, (draft) => { | const newInputs = produce(inputs, (draft) => { | ||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets) | |||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) | |||||
| }) | }) | ||||
| setInputs(newInputs) | setInputs(newInputs) | ||||
| }, [inputs, setInputs, selectedDatasets]) | |||||
| }, [inputs, setInputs, selectedDatasets, currentRerankModel]) | |||||
| // datasets | // datasets | ||||
| useEffect(() => { | useEffect(() => { | ||||
| if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | ||||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | const multipleRetrievalConfig = draft.multiple_retrieval_config | ||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets) | |||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) | |||||
| } | } | ||||
| }) | }) | ||||
| setInputs(newInputs) | setInputs(newInputs) | ||||
| || (allExternal && newDatasets.length > 1) | || (allExternal && newDatasets.length > 1) | ||||
| ) | ) | ||||
| setRerankModelOpen(true) | setRerankModelOpen(true) | ||||
| }, [inputs, setInputs, payload.retrieval_mode]) | |||||
| }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) | |||||
| const filterVar = useCallback((varPayload: Var) => { | const filterVar = useCallback((varPayload: Var) => { | ||||
| return varPayload.type === VarType.string | return varPayload.type === VarType.string | 
| import { uniq } from 'lodash-es' | |||||
| import { | |||||
| uniq, | |||||
| xorBy, | |||||
| } from 'lodash-es' | |||||
| import type { MultipleRetrievalConfig } from './types' | import type { MultipleRetrievalConfig } from './types' | ||||
| import type { | import type { | ||||
| DataSet, | DataSet, | ||||
| return true | return true | ||||
| } | } | ||||
| export const getSelectedDatasetsMode = (datasets: DataSet[]) => { | |||||
| export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => { | |||||
| if (datasets === null) | |||||
| datasets = [] | |||||
| let allHighQuality = true | let allHighQuality = true | ||||
| let allHighQualityVectorSearch = true | let allHighQualityVectorSearch = true | ||||
| let allHighQualityFullTextSearch = true | let allHighQualityFullTextSearch = true | ||||
| } as SelectedDatasetsMode | } as SelectedDatasetsMode | ||||
| } | } | ||||
| export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => { | |||||
| export const getMultipleRetrievalConfig = ( | |||||
| multipleRetrievalConfig: MultipleRetrievalConfig, | |||||
| selectedDatasets: DataSet[], | |||||
| originalDatasets: DataSet[], | |||||
| isValidRerankModel?: boolean, | |||||
| ) => { | |||||
| const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||||
| const { | const { | ||||
| allHighQuality, | allHighQuality, | ||||
| allHighQualityVectorSearch, | allHighQualityVectorSearch, | ||||
| result.reranking_mode = RerankingModeEnum.WeightedScore | result.reranking_mode = RerankingModeEnum.WeightedScore | ||||
| if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { | if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { | ||||
| if (!isValidRerankModel) | |||||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||||
| else | |||||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||||
| result.weights = { | |||||
| vector_setting: { | |||||
| vector_weight: allHighQualityVectorSearch | |||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic | |||||
| : allHighQualityFullTextSearch | |||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic | |||||
| : DEFAULT_WEIGHTED_SCORE.other.semantic, | |||||
| embedding_provider_name: selectedDatasets[0].embedding_model_provider, | |||||
| embedding_model_name: selectedDatasets[0].embedding_model, | |||||
| }, | |||||
| keyword_setting: { | |||||
| keyword_weight: allHighQualityVectorSearch | |||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword | |||||
| : allHighQualityFullTextSearch | |||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword | |||||
| : DEFAULT_WEIGHTED_SCORE.other.keyword, | |||||
| }, | |||||
| } | |||||
| } | |||||
| if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { | |||||
| if (!isValidRerankModel) | |||||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||||
| else | |||||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||||
| result.weights = { | result.weights = { | ||||
| vector_setting: { | vector_setting: { | ||||
| vector_weight: allHighQualityVectorSearch | vector_weight: allHighQualityVectorSearch | 
| semantic: 0, | semantic: 0, | ||||
| keyword: 1.0, | keyword: 1.0, | ||||
| }, | }, | ||||
| semanticFirst: { | |||||
| semantic: 0.7, | |||||
| keyword: 0.3, | |||||
| }, | |||||
| keywordFirst: { | |||||
| semantic: 0.3, | |||||
| keyword: 0.7, | |||||
| }, | |||||
| other: { | other: { | ||||
| semantic: 0.7, | semantic: 0.7, | ||||
| keyword: 0.3, | keyword: 0.3, |