| @@ -13,6 +13,11 @@ import ContextVar from './context-var' | |||
| import ConfigContext from '@/context/debug-configuration' | |||
| import { AppType } from '@/types/app' | |||
| import type { DataSet } from '@/models/datasets' | |||
| import { | |||
| getMultipleRetrievalConfig, | |||
| } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| const Icon = ( | |||
| <svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> | |||
| @@ -31,13 +36,25 @@ const DatasetConfig: FC = () => { | |||
| setModelConfig, | |||
| showSelectDataSet, | |||
| isAgent, | |||
| datasetConfigs, | |||
| setDatasetConfigs, | |||
| } = useContext(ConfigContext) | |||
| const formattingChangedDispatcher = useFormattingChangedDispatcher() | |||
| const hasData = dataSet.length > 0 | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const onRemove = (id: string) => { | |||
| setDataSet(dataSet.filter(item => item.id !== id)) | |||
| const filteredDataSets = dataSet.filter(item => item.id !== id) | |||
| setDataSet(filteredDataSets) | |||
| const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) | |||
| setDatasetConfigs({ | |||
| ...(datasetConfigs as any), | |||
| ...retrievalConfig, | |||
| }) | |||
| formattingChangedDispatcher() | |||
| } | |||
| @@ -55,7 +55,7 @@ const ConfigContent: FC<Props> = ({ | |||
| retrieval_model: RETRIEVE_TYPE.multiWay, | |||
| }, isInWorkflow) | |||
| } | |||
| }, [type]) | |||
| }, [type, datasetConfigs, isInWorkflow, onChange]) | |||
| const { | |||
| modelList: rerankModelList, | |||
| @@ -16,7 +16,6 @@ import type { DataSet } from '@/models/datasets' | |||
| import type { DatasetConfigs } from '@/models/debug' | |||
| import { | |||
| getMultipleRetrievalConfig, | |||
| getSelectedDatasetsMode, | |||
| } from '@/app/components/workflow/nodes/knowledge-retrieval/utils' | |||
| type ParamsConfigProps = { | |||
| @@ -37,57 +36,8 @@ const ParamsConfig = ({ | |||
| const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs) | |||
| useEffect(() => { | |||
| const { | |||
| allEconomic, | |||
| allHighQuality, | |||
| allHighQualityFullTextSearch, | |||
| allHighQualityVectorSearch, | |||
| allExternal, | |||
| mixtureHighQualityAndEconomic, | |||
| inconsistentEmbeddingModel, | |||
| mixtureInternalAndExternal, | |||
| } = getSelectedDatasetsMode(selectedDatasets) | |||
| if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) | |||
| setRerankSettingModalOpen(false) | |||
| if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1)) | |||
| setRerankSettingModalOpen(true) | |||
| }, [selectedDatasets]) | |||
| useEffect(() => { | |||
| const { | |||
| allEconomic, | |||
| allInternal, | |||
| allExternal, | |||
| } = getSelectedDatasetsMode(selectedDatasets) | |||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | |||
| let rerankEnable = restConfigs.reranking_enable | |||
| if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) | |||
| rerankEnable = false | |||
| setTempDataSetConfigs({ | |||
| ...getMultipleRetrievalConfig({ | |||
| top_k: restConfigs.top_k, | |||
| score_threshold: restConfigs.score_threshold, | |||
| reranking_model: restConfigs.reranking_model && { | |||
| provider: restConfigs.reranking_model.reranking_provider_name, | |||
| model: restConfigs.reranking_model.reranking_model_name, | |||
| }, | |||
| reranking_mode: restConfigs.reranking_mode, | |||
| weights: restConfigs.weights, | |||
| reranking_enable: rerankEnable, | |||
| }, selectedDatasets), | |||
| reranking_model: restConfigs.reranking_model && { | |||
| reranking_provider_name: restConfigs.reranking_model.reranking_provider_name, | |||
| reranking_model_name: restConfigs.reranking_model.reranking_model_name, | |||
| }, | |||
| retrieval_model, | |||
| score_threshold_enabled, | |||
| datasets, | |||
| }) | |||
| }, [selectedDatasets, datasetConfigs]) | |||
| setTempDataSetConfigs(datasetConfigs) | |||
| }, [datasetConfigs]) | |||
| const { | |||
| defaultModel: rerankDefaultModel, | |||
| @@ -135,7 +85,7 @@ const ParamsConfig = ({ | |||
| reranking_mode: restConfigs.reranking_mode, | |||
| weights: restConfigs.weights, | |||
| reranking_enable: restConfigs.reranking_enable, | |||
| }, selectedDatasets) | |||
| }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) | |||
| setTempDataSetConfigs({ | |||
| ...retrievalConfig, | |||
| @@ -180,6 +130,7 @@ const ParamsConfig = ({ | |||
| <div className='mt-6 flex justify-end'> | |||
| <Button className='mr-2 flex-shrink-0' onClick={() => { | |||
| setTempDataSetConfigs(datasetConfigs) | |||
| setRerankSettingModalOpen(false) | |||
| }}>{t('common.operation.cancel')}</Button> | |||
| <Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button> | |||
| @@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration' | |||
| import Config from '@/app/components/app/configuration/config' | |||
| import Debug from '@/app/components/app/configuration/debug' | |||
| import Confirm from '@/app/components/base/confirm' | |||
| import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { ToastContext } from '@/app/components/base/toast' | |||
| import { fetchAppDetail, updateAppModelConfig } from '@/service/apps' | |||
| import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config' | |||
| @@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' | |||
| import Drawer from '@/app/components/base/drawer' | |||
| import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' | |||
| import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { | |||
| useModelListAndDefaultModelAndCurrentProviderAndModel, | |||
| useTextGenerationCurrentProviderAndModelAndModelList, | |||
| } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { fetchCollectionList } from '@/service/tools' | |||
| import { type Collection } from '@/app/components/tools/types' | |||
| import { useStore as useAppStore } from '@/app/components/app/store' | |||
| @@ -217,6 +220,9 @@ const Configuration: FC = () => { | |||
| const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false) | |||
| const selectedIds = dataSets.map(item => item.id) | |||
| const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const handleSelect = (data: DataSet[]) => { | |||
| if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { | |||
| hideSelectDataSet() | |||
| @@ -263,7 +269,7 @@ const Configuration: FC = () => { | |||
| reranking_mode: restConfigs.reranking_mode, | |||
| weights: restConfigs.weights, | |||
| reranking_enable: restConfigs.reranking_enable, | |||
| }, newDatasets) | |||
| }, newDatasets, dataSets, !!currentRerankModel) | |||
| setDatasetConfigs({ | |||
| ...retrievalConfig, | |||
| @@ -603,9 +609,11 @@ const Configuration: FC = () => { | |||
| syncToPublishedConfig(config) | |||
| setPublishedConfig(config) | |||
| const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) | |||
| setDatasetConfigs({ | |||
| retrieval_model: RETRIEVE_TYPE.multiWay, | |||
| ...modelConfig.dataset_configs, | |||
| ...retrievalConfig, | |||
| }) | |||
| setHasFetchedDetail(true) | |||
| }) | |||
| @@ -163,7 +163,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| draft.retrieval_mode = newMode | |||
| if (newMode === RETRIEVE_TYPE.multiWay) { | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets) | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) | |||
| } | |||
| else { | |||
| const hasSetModel = draft.single_retrieval_config?.model?.provider | |||
| @@ -180,14 +180,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets]) | |||
| }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) | |||
| const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets) | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs, selectedDatasets]) | |||
| }, [inputs, setInputs, selectedDatasets, currentRerankModel]) | |||
| // datasets | |||
| useEffect(() => { | |||
| @@ -231,7 +231,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets) | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) | |||
| } | |||
| }) | |||
| setInputs(newInputs) | |||
| @@ -243,7 +243,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| || (allExternal && newDatasets.length > 1) | |||
| ) | |||
| setRerankModelOpen(true) | |||
| }, [inputs, setInputs, payload.retrieval_mode]) | |||
| }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) | |||
| const filterVar = useCallback((varPayload: Var) => { | |||
| return varPayload.type === VarType.string | |||
| @@ -1,4 +1,7 @@ | |||
| import { uniq } from 'lodash-es' | |||
| import { | |||
| uniq, | |||
| xorBy, | |||
| } from 'lodash-es' | |||
| import type { MultipleRetrievalConfig } from './types' | |||
| import type { | |||
| DataSet, | |||
| @@ -15,7 +18,9 @@ export const checkNodeValid = () => { | |||
| return true | |||
| } | |||
| export const getSelectedDatasetsMode = (datasets: DataSet[]) => { | |||
| export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => { | |||
| if (datasets === null) | |||
| datasets = [] | |||
| let allHighQuality = true | |||
| let allHighQualityVectorSearch = true | |||
| let allHighQualityFullTextSearch = true | |||
| @@ -85,7 +90,14 @@ export const getSelectedDatasetsMode = (datasets: DataSet[]) => { | |||
| } as SelectedDatasetsMode | |||
| } | |||
| export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => { | |||
| export const getMultipleRetrievalConfig = ( | |||
| multipleRetrievalConfig: MultipleRetrievalConfig, | |||
| selectedDatasets: DataSet[], | |||
| originalDatasets: DataSet[], | |||
| isValidRerankModel?: boolean, | |||
| ) => { | |||
| const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||
| const { | |||
| allHighQuality, | |||
| allHighQualityVectorSearch, | |||
| @@ -123,6 +135,37 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { | |||
| if (!isValidRerankModel) | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| else | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.weights = { | |||
| vector_setting: { | |||
| vector_weight: allHighQualityVectorSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic | |||
| : allHighQualityFullTextSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic | |||
| : DEFAULT_WEIGHTED_SCORE.other.semantic, | |||
| embedding_provider_name: selectedDatasets[0].embedding_model_provider, | |||
| embedding_model_name: selectedDatasets[0].embedding_model, | |||
| }, | |||
| keyword_setting: { | |||
| keyword_weight: allHighQualityVectorSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword | |||
| : allHighQualityFullTextSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword | |||
| : DEFAULT_WEIGHTED_SCORE.other.keyword, | |||
| }, | |||
| } | |||
| } | |||
| if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { | |||
| if (!isValidRerankModel) | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| else | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.weights = { | |||
| vector_setting: { | |||
| vector_weight: allHighQualityVectorSearch | |||
| @@ -566,14 +566,6 @@ export const DEFAULT_WEIGHTED_SCORE = { | |||
| semantic: 0, | |||
| keyword: 1.0, | |||
| }, | |||
| semanticFirst: { | |||
| semantic: 0.7, | |||
| keyword: 0.3, | |||
| }, | |||
| keywordFirst: { | |||
| semantic: 0.3, | |||
| keyword: 0.7, | |||
| }, | |||
| other: { | |||
| semantic: 0.7, | |||
| keyword: 0.3, | |||