Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>tags/1.9.1
| @@ -65,13 +65,40 @@ const DatasetConfig: FC = () => { | |||
| const onRemove = (id: string) => { | |||
| const filteredDataSets = dataSet.filter(item => item.id !== id) | |||
| setDataSet(filteredDataSets) | |||
| const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { | |||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | |||
| const { | |||
| top_k, | |||
| score_threshold, | |||
| reranking_model, | |||
| reranking_mode, | |||
| weights, | |||
| reranking_enable, | |||
| } = restConfigs | |||
| const oldRetrievalConfig = { | |||
| top_k, | |||
| score_threshold, | |||
| reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { | |||
| provider: reranking_model.reranking_provider_name, | |||
| model: reranking_model.reranking_model_name, | |||
| } : undefined, | |||
| reranking_mode, | |||
| weights, | |||
| reranking_enable, | |||
| } | |||
| const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, filteredDataSets, dataSet, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| setDatasetConfigs({ | |||
| ...(datasetConfigs as any), | |||
| ...datasetConfigsRef.current, | |||
| ...retrievalConfig, | |||
| reranking_model: { | |||
| reranking_provider_name: retrievalConfig?.reranking_model?.provider || '', | |||
| reranking_model_name: retrievalConfig?.reranking_model?.model || '', | |||
| }, | |||
| retrieval_model, | |||
| score_threshold_enabled, | |||
| datasets, | |||
| }) | |||
| const { | |||
| allExternal, | |||
| @@ -30,11 +30,11 @@ import { noop } from 'lodash-es' | |||
| type Props = { | |||
| datasetConfigs: DatasetConfigs | |||
| onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void | |||
| selectedDatasets?: DataSet[] | |||
| isInWorkflow?: boolean | |||
| singleRetrievalModelConfig?: ModelConfig | |||
| onSingleRetrievalModelChange?: (config: ModelConfig) => void | |||
| onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void | |||
| selectedDatasets?: DataSet[] | |||
| } | |||
| const ConfigContent: FC<Props> = ({ | |||
| @@ -61,22 +61,28 @@ const ConfigContent: FC<Props> = ({ | |||
| const { | |||
| modelList: rerankModelList, | |||
| currentModel: validDefaultRerankModel, | |||
| currentProvider: validDefaultRerankProvider, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| /** | |||
| * If reranking model is set and is valid, use the reranking model | |||
| * Otherwise, check if the default reranking model is valid | |||
| */ | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| { | |||
| provider: datasetConfigs.reranking_model?.reranking_provider_name, | |||
| model: datasetConfigs.reranking_model?.reranking_model_name, | |||
| provider: datasetConfigs.reranking_model?.reranking_provider_name || validDefaultRerankProvider?.provider || '', | |||
| model: datasetConfigs.reranking_model?.reranking_model_name || validDefaultRerankModel?.model || '', | |||
| }, | |||
| ) | |||
| const rerankModel = useMemo(() => { | |||
| return { | |||
| provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', | |||
| model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', | |||
| provider_name: datasetConfigs.reranking_model?.reranking_provider_name ?? '', | |||
| model_name: datasetConfigs.reranking_model?.reranking_model_name ?? '', | |||
| } | |||
| }, [datasetConfigs.reranking_model]) | |||
| @@ -135,7 +141,7 @@ const ConfigContent: FC<Props> = ({ | |||
| }) | |||
| } | |||
| const model = singleRetrievalConfig | |||
| const model = singleRetrievalConfig // Legacy code, for compatibility, have to keep it | |||
| const rerankingModeOptions = [ | |||
| { | |||
| @@ -158,7 +164,7 @@ const ConfigContent: FC<Props> = ({ | |||
| const canManuallyToggleRerank = useMemo(() => { | |||
| return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | |||
| || selectedDatasetsMode.allExternal | |||
| || selectedDatasetsMode.allExternal | |||
| }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | |||
| const showRerankModel = useMemo(() => { | |||
| @@ -168,7 +174,7 @@ const ConfigContent: FC<Props> = ({ | |||
| return datasetConfigs.reranking_enable | |||
| }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) | |||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||
| const handleManuallyToggleRerank = useCallback((enable: boolean) => { | |||
| if (!currentRerankModel && enable) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| onChange({ | |||
| @@ -255,12 +261,11 @@ const ConfigContent: FC<Props> = ({ | |||
| <div className='mt-2'> | |||
| <div className='flex items-center'> | |||
| { | |||
| selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( | |||
| canManuallyToggleRerank && ( | |||
| <Switch | |||
| size='md' | |||
| defaultValue={showRerankModel} | |||
| disabled={!canManuallyToggleRerank} | |||
| onChange={handleDisabledSwitchClick} | |||
| onChange={handleManuallyToggleRerank} | |||
| /> | |||
| ) | |||
| } | |||
| @@ -284,18 +284,28 @@ const Configuration: FC = () => { | |||
| setRerankSettingModalOpen(true) | |||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | |||
| const { | |||
| top_k, | |||
| score_threshold, | |||
| reranking_model, | |||
| reranking_mode, | |||
| weights, | |||
| reranking_enable, | |||
| } = restConfigs | |||
| const oldRetrievalConfig = { | |||
| top_k, | |||
| score_threshold, | |||
| reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { | |||
| provider: reranking_model.reranking_provider_name, | |||
| model: reranking_model.reranking_model_name, | |||
| } : undefined, | |||
| reranking_mode, | |||
| weights, | |||
| reranking_enable, | |||
| } | |||
| const retrievalConfig = getMultipleRetrievalConfig({ | |||
| top_k: restConfigs.top_k, | |||
| score_threshold: restConfigs.score_threshold, | |||
| reranking_model: restConfigs.reranking_model && { | |||
| provider: restConfigs.reranking_model.reranking_provider_name, | |||
| model: restConfigs.reranking_model.reranking_model_name, | |||
| }, | |||
| reranking_mode: restConfigs.reranking_mode, | |||
| weights: restConfigs.weights, | |||
| reranking_enable: restConfigs.reranking_enable, | |||
| }, newDatasets, dataSets, { | |||
| const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, newDatasets, dataSets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| @@ -40,7 +40,7 @@ const RetrievalMethodConfig: FC<Props> = ({ | |||
| onChange({ | |||
| ...value, | |||
| search_method: retrieveMethod, | |||
| ...(!value.reranking_model.reranking_model_name | |||
| ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) | |||
| ? { | |||
| reranking_model: { | |||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | |||
| @@ -57,7 +57,7 @@ const RetrievalMethodConfig: FC<Props> = ({ | |||
| onChange({ | |||
| ...value, | |||
| search_method: retrieveMethod, | |||
| ...(!value.reranking_model.reranking_model_name | |||
| ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) | |||
| ? { | |||
| reranking_model: { | |||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | |||
| @@ -54,7 +54,7 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| }, | |||
| ) | |||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||
| const handleToggleRerankEnable = useCallback((enable: boolean) => { | |||
| if (enable && !currentModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| onChange({ | |||
| @@ -119,7 +119,7 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| <Switch | |||
| size='md' | |||
| defaultValue={value.reranking_enable} | |||
| onChange={handleDisabledSwitchClick} | |||
| onChange={handleToggleRerankEnable} | |||
| /> | |||
| )} | |||
| <div className='flex items-center'> | |||
| @@ -1,6 +1,6 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React, { useCallback, useState } from 'react' | |||
| import React, { useCallback, useMemo } from 'react' | |||
| import { RiEqualizer2Line } from '@remixicon/react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' | |||
| @@ -14,8 +14,6 @@ import { | |||
| import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content' | |||
| import { RETRIEVE_TYPE } from '@/types/app' | |||
| import { DATASET_DEFAULT } from '@/config' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import Button from '@/app/components/base/button' | |||
| import type { DatasetConfigs } from '@/models/debug' | |||
| import type { DataSet } from '@/models/datasets' | |||
| @@ -32,8 +30,8 @@ type Props = { | |||
| onSingleRetrievalModelChange?: (config: ModelConfig) => void | |||
| onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void | |||
| readonly?: boolean | |||
| openFromProps?: boolean | |||
| onOpenFromPropsChange?: (openFromProps: boolean) => void | |||
| rerankModalOpen: boolean | |||
| onRerankModelOpenChange: (open: boolean) => void | |||
| selectedDatasets: DataSet[] | |||
| } | |||
| @@ -45,26 +43,52 @@ const RetrievalConfig: FC<Props> = ({ | |||
| onSingleRetrievalModelChange, | |||
| onSingleRetrievalModelParamsChange, | |||
| readonly, | |||
| openFromProps, | |||
| onOpenFromPropsChange, | |||
| rerankModalOpen, | |||
| onRerankModelOpenChange, | |||
| selectedDatasets, | |||
| }) => { | |||
| const { t } = useTranslation() | |||
| const [open, setOpen] = useState(false) | |||
| const mergedOpen = openFromProps !== undefined ? openFromProps : open | |||
| const { retrieval_mode, multiple_retrieval_config } = payload | |||
| const handleOpen = useCallback((newOpen: boolean) => { | |||
| setOpen(newOpen) | |||
| onOpenFromPropsChange?.(newOpen) | |||
| }, [onOpenFromPropsChange]) | |||
| onRerankModelOpenChange(newOpen) | |||
| }, [onRerankModelOpenChange]) | |||
| const { | |||
| currentProvider: validRerankDefaultProvider, | |||
| currentModel: validRerankDefaultModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const datasetConfigs = useMemo(() => { | |||
| const { | |||
| reranking_model, | |||
| top_k, | |||
| score_threshold, | |||
| reranking_mode, | |||
| weights, | |||
| reranking_enable, | |||
| } = multiple_retrieval_config || {} | |||
| return { | |||
| retrieval_model: retrieval_mode, | |||
| reranking_model: (reranking_model?.provider && reranking_model?.model) | |||
| ? { | |||
| reranking_provider_name: reranking_model?.provider, | |||
| reranking_model_name: reranking_model?.model, | |||
| } | |||
| : { | |||
| reranking_provider_name: '', | |||
| reranking_model_name: '', | |||
| }, | |||
| top_k: top_k || DATASET_DEFAULT.top_k, | |||
| score_threshold_enabled: !(score_threshold === undefined || score_threshold === null), | |||
| score_threshold, | |||
| datasets: { | |||
| datasets: [], | |||
| }, | |||
| reranking_mode, | |||
| weights, | |||
| reranking_enable, | |||
| } | |||
| }, [retrieval_mode, multiple_retrieval_config]) | |||
| const { multiple_retrieval_config } = payload | |||
| const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { | |||
| // Legacy code, for compatibility, have to keep it | |||
| if (isRetrievalModeChange) { | |||
| onRetrievalModeChange(configs.retrieval_model) | |||
| return | |||
| @@ -72,13 +96,11 @@ const RetrievalConfig: FC<Props> = ({ | |||
| onMultipleRetrievalConfigChange({ | |||
| top_k: configs.top_k, | |||
| score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, | |||
| reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay | |||
| reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay | |||
| ? undefined | |||
| // eslint-disable-next-line sonarjs/no-nested-conditional | |||
| : (!configs.reranking_model?.reranking_provider_name | |||
| ? { | |||
| provider: validRerankDefaultProvider?.provider || '', | |||
| model: validRerankDefaultModel?.model || '', | |||
| } | |||
| ? undefined | |||
| : { | |||
| provider: configs.reranking_model?.reranking_provider_name, | |||
| model: configs.reranking_model?.reranking_model_name, | |||
| @@ -87,11 +109,11 @@ const RetrievalConfig: FC<Props> = ({ | |||
| weights: configs.weights, | |||
| reranking_enable: configs.reranking_enable, | |||
| }) | |||
| }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, validRerankDefaultProvider, validRerankDefaultModel, onRetrievalModeChange]) | |||
| }, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange]) | |||
| return ( | |||
| <PortalToFollowElem | |||
| open={mergedOpen} | |||
| open={rerankModalOpen} | |||
| onOpenChange={handleOpen} | |||
| placement='bottom-end' | |||
| offset={{ | |||
| @@ -102,14 +124,14 @@ const RetrievalConfig: FC<Props> = ({ | |||
| onClick={() => { | |||
| if (readonly) | |||
| return | |||
| handleOpen(!mergedOpen) | |||
| handleOpen(!rerankModalOpen) | |||
| }} | |||
| > | |||
| <Button | |||
| variant='ghost' | |||
| size='small' | |||
| disabled={readonly} | |||
| className={cn(open && 'bg-components-button-ghost-bg-hover')} | |||
| className={cn(rerankModalOpen && 'bg-components-button-ghost-bg-hover')} | |||
| > | |||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | |||
| {t('dataset.retrievalSettings')} | |||
| @@ -118,35 +140,13 @@ const RetrievalConfig: FC<Props> = ({ | |||
| <PortalToFollowElemContent style={{ zIndex: 1001 }}> | |||
| <div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'> | |||
| <ConfigRetrievalContent | |||
| datasetConfigs={ | |||
| { | |||
| retrieval_model: payload.retrieval_mode, | |||
| reranking_model: multiple_retrieval_config?.reranking_model?.provider | |||
| ? { | |||
| reranking_provider_name: multiple_retrieval_config.reranking_model?.provider, | |||
| reranking_model_name: multiple_retrieval_config.reranking_model?.model, | |||
| } | |||
| : { | |||
| reranking_provider_name: '', | |||
| reranking_model_name: '', | |||
| }, | |||
| top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k, | |||
| score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config.score_threshold === null), | |||
| score_threshold: multiple_retrieval_config?.score_threshold, | |||
| datasets: { | |||
| datasets: [], | |||
| }, | |||
| reranking_mode: multiple_retrieval_config?.reranking_mode, | |||
| weights: multiple_retrieval_config?.weights, | |||
| reranking_enable: multiple_retrieval_config?.reranking_enable, | |||
| } | |||
| } | |||
| datasetConfigs={datasetConfigs} | |||
| onChange={handleChange} | |||
| selectedDatasets={selectedDatasets} | |||
| isInWorkflow | |||
| singleRetrievalModelConfig={singleRetrievalModelConfig} | |||
| onSingleRetrievalModelChange={onSingleRetrievalModelChange} | |||
| onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange} | |||
| selectedDatasets={selectedDatasets} | |||
| /> | |||
| </div> | |||
| </PortalToFollowElemContent> | |||
| @@ -1,6 +1,6 @@ | |||
| import type { NodeDefault } from '../../types' | |||
| import type { KnowledgeRetrievalNodeType } from './types' | |||
| import { checkoutRerankModelConfigedInRetrievalSettings } from './utils' | |||
| import { checkoutRerankModelConfiguredInRetrievalSettings } from './utils' | |||
| import { DATASET_DEFAULT } from '@/config' | |||
| import { RETRIEVE_TYPE } from '@/types/app' | |||
| import { genNodeMetaData } from '@/app/components/workflow/utils' | |||
| @@ -36,7 +36,7 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = { | |||
| const { _datasets, multiple_retrieval_config, retrieval_mode } = payload | |||
| if (retrieval_mode === RETRIEVE_TYPE.multiWay) { | |||
| const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config) | |||
| const checked = checkoutRerankModelConfiguredInRetrievalSettings(_datasets || [], multiple_retrieval_config) | |||
| if (!errorMessages && !checked) | |||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) | |||
| @@ -1,7 +1,6 @@ | |||
| import type { FC } from 'react' | |||
| import { | |||
| memo, | |||
| useCallback, | |||
| useMemo, | |||
| } from 'react' | |||
| import { intersectionBy } from 'lodash-es' | |||
| @@ -53,10 +52,6 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({ | |||
| availableNumberNodesWithParent, | |||
| } = useConfig(id, data) | |||
| const handleOpenFromPropsChange = useCallback((openFromProps: boolean) => { | |||
| setRerankModelOpen(openFromProps) | |||
| }, [setRerankModelOpen]) | |||
| const metadataList = useMemo(() => { | |||
| return intersectionBy(...selectedDatasets.filter((dataset) => { | |||
| return !!dataset.doc_metadata | |||
| @@ -68,7 +63,6 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({ | |||
| return ( | |||
| <div className='pt-2'> | |||
| <div className='space-y-4 px-4 pb-2'> | |||
| {/* {JSON.stringify(inputs, null, 2)} */} | |||
| <Field | |||
| title={t(`${i18nPrefix}.queryVariable`)} | |||
| required | |||
| @@ -100,8 +94,8 @@ const Panel: FC<NodePanelProps<KnowledgeRetrievalNodeType>> = ({ | |||
| onSingleRetrievalModelChange={handleModelChanged as any} | |||
| onSingleRetrievalModelParamsChange={handleCompletionParamsChange} | |||
| readonly={readOnly || !selectedDatasets.length} | |||
| openFromProps={rerankModelOpen} | |||
| onOpenFromPropsChange={handleOpenFromPropsChange} | |||
| rerankModalOpen={rerankModelOpen} | |||
| onRerankModelOpenChange={setRerankModelOpen} | |||
| selectedDatasets={selectedDatasets} | |||
| /> | |||
| {!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)} | |||
| @@ -204,10 +204,11 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { | |||
| const newMultipleRetrievalConfig = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| draft.multiple_retrieval_config = newMultipleRetrievalConfig | |||
| }) | |||
| setInputs(newInputs) | |||
| }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) | |||
| @@ -254,10 +255,11 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { | |||
| const newMultipleRetrievalConfig = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { | |||
| provider: currentRerankProvider?.provider, | |||
| model: currentRerankModel?.model, | |||
| }) | |||
| draft.multiple_retrieval_config = newMultipleRetrievalConfig | |||
| } | |||
| }) | |||
| updateDatasetsDetail(newDatasets) | |||
| @@ -10,6 +10,7 @@ import type { | |||
| import { | |||
| DEFAULT_WEIGHTED_SCORE, | |||
| RerankingModeEnum, | |||
| WeightedScoreEnum, | |||
| } from '@/models/datasets' | |||
| import { RETRIEVE_METHOD } from '@/types/app' | |||
| import { DATASET_DEFAULT } from '@/config' | |||
| @@ -93,10 +94,12 @@ export const getMultipleRetrievalConfig = ( | |||
| multipleRetrievalConfig: MultipleRetrievalConfig, | |||
| selectedDatasets: DataSet[], | |||
| originalDatasets: DataSet[], | |||
| validRerankModel?: { provider?: string; model?: string }, | |||
| fallbackRerankModel?: { provider?: string; model?: string }, // fallback rerank model | |||
| ) => { | |||
| const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||
| const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model | |||
| // Check if the selected datasets are different from the original datasets | |||
| const isDatasetsChanged = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||
| // Check if the rerank model is valid | |||
| const isFallbackRerankModelValid = !!(fallbackRerankModel?.provider && fallbackRerankModel?.model) | |||
| const { | |||
| allHighQuality, | |||
| @@ -125,14 +128,16 @@ export const getMultipleRetrievalConfig = ( | |||
| reranking_mode, | |||
| reranking_model, | |||
| weights, | |||
| reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : shouldSetWeightDefaultValue, | |||
| reranking_enable, | |||
| } | |||
| const setDefaultWeights = () => { | |||
| result.weights = { | |||
| weight_type: WeightedScoreEnum.Customized, | |||
| vector_setting: { | |||
| vector_weight: allHighQualityVectorSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic | |||
| // eslint-disable-next-line sonarjs/no-nested-conditional | |||
| : allHighQualityFullTextSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic | |||
| : DEFAULT_WEIGHTED_SCORE.other.semantic, | |||
| @@ -142,6 +147,7 @@ export const getMultipleRetrievalConfig = ( | |||
| keyword_setting: { | |||
| keyword_weight: allHighQualityVectorSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword | |||
| // eslint-disable-next-line sonarjs/no-nested-conditional | |||
| : allHighQualityFullTextSearch | |||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword | |||
| : DEFAULT_WEIGHTED_SCORE.other.keyword, | |||
| @@ -149,65 +155,106 @@ export const getMultipleRetrievalConfig = ( | |||
| } | |||
| } | |||
| if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) { | |||
| /** | |||
| * In this case, user can manually toggle reranking | |||
| * So should keep the reranking_enable value | |||
| * But the default reranking_model should be set | |||
| */ | |||
| if ((allEconomic && allInternal) || allExternal) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| if (!result.reranking_model?.provider || !result.reranking_model?.model) { | |||
| if (rerankModelIsValid) { | |||
| result.reranking_enable = reranking_enable !== false | |||
| result.reranking_model = { | |||
| provider: validRerankModel?.provider || '', | |||
| model: validRerankModel?.model || '', | |||
| } | |||
| } | |||
| else { | |||
| result.reranking_model = { | |||
| provider: '', | |||
| model: '', | |||
| } | |||
| // Need to check if the reranking model should be set to default when first time initialized | |||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||
| result.reranking_model = { | |||
| provider: fallbackRerankModel.provider || '', | |||
| model: fallbackRerankModel.model || '', | |||
| } | |||
| } | |||
| else { | |||
| result.reranking_enable = reranking_enable !== false | |||
| result.reranking_enable = reranking_enable | |||
| } | |||
| /** | |||
| * In this case, reranking_enable must be true | |||
| * And if rerank model is not set, should set the default rerank model | |||
| */ | |||
| if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| // Need to check if the reranking model should be set to default when first time initialized | |||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||
| result.reranking_model = { | |||
| provider: fallbackRerankModel.provider || '', | |||
| model: fallbackRerankModel.model || '', | |||
| } | |||
| } | |||
| result.reranking_enable = true | |||
| } | |||
| /** | |||
| * In this case, user can choose to use weighted score or rerank model | |||
| * But if the reranking_mode is not initialized, should set the default rerank model and reranking_enable to true | |||
| * and set reranking_mode to reranking_model | |||
| */ | |||
| if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { | |||
| // If not initialized, check if the default rerank model is valid | |||
| if (!reranking_mode) { | |||
| if (validRerankModel?.provider && validRerankModel?.model) { | |||
| if (isFallbackRerankModelValid) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.reranking_enable = reranking_enable !== false | |||
| result.reranking_enable = true | |||
| result.reranking_model = { | |||
| provider: validRerankModel.provider, | |||
| model: validRerankModel.model, | |||
| provider: fallbackRerankModel.provider || '', | |||
| model: fallbackRerankModel.model || '', | |||
| } | |||
| } | |||
| else { | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| result.reranking_enable = false | |||
| setDefaultWeights() | |||
| } | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.WeightedScore && !weights) | |||
| setDefaultWeights() | |||
| // After initialization, if datasets has no change, make sure the config has correct value | |||
| if (reranking_mode === RerankingModeEnum.WeightedScore) { | |||
| result.reranking_enable = false | |||
| if (!weights) | |||
| setDefaultWeights() | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.RerankingModel) { | |||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||
| result.reranking_model = { | |||
| provider: fallbackRerankModel.provider || '', | |||
| model: fallbackRerankModel.model || '', | |||
| } | |||
| } | |||
| result.reranking_enable = true | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) { | |||
| if (rerankModelIsValid) { | |||
| // Need to check if reranking_mode should be set to reranking_model when datasets changed | |||
| if (reranking_mode === RerankingModeEnum.WeightedScore && weights && isDatasetsChanged) { | |||
| if ((result.reranking_model?.provider && result.reranking_model?.model) || isFallbackRerankModelValid) { | |||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||
| result.reranking_enable = reranking_enable !== false | |||
| result.reranking_enable = true | |||
| result.reranking_model = { | |||
| provider: validRerankModel.provider || '', | |||
| model: validRerankModel.model || '', | |||
| // eslint-disable-next-line sonarjs/nested-control-flow | |||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||
| result.reranking_model = { | |||
| provider: fallbackRerankModel.provider || '', | |||
| model: fallbackRerankModel.model || '', | |||
| } | |||
| } | |||
| } | |||
| else { | |||
| setDefaultWeights() | |||
| } | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { | |||
| // Need to switch to weighted score when reranking model is not valid and datasets changed | |||
| if ( | |||
| reranking_mode === RerankingModeEnum.RerankingModel | |||
| && (!result.reranking_model?.provider || !result.reranking_model?.model) | |||
| && !isFallbackRerankModelValid | |||
| && isDatasetsChanged | |||
| ) { | |||
| result.reranking_mode = RerankingModeEnum.WeightedScore | |||
| result.reranking_enable = false | |||
| setDefaultWeights() | |||
| } | |||
| } | |||
| @@ -215,7 +262,7 @@ export const getMultipleRetrievalConfig = ( | |||
| return result | |||
| } | |||
| export const checkoutRerankModelConfigedInRetrievalSettings = ( | |||
| export const checkoutRerankModelConfiguredInRetrievalSettings = ( | |||
| datasets: DataSet[], | |||
| multipleRetrievalConfig?: MultipleRetrievalConfig, | |||
| ) => { | |||
| @@ -225,6 +272,7 @@ export const checkoutRerankModelConfigedInRetrievalSettings = ( | |||
| const { | |||
| allEconomic, | |||
| allExternal, | |||
| allInternal, | |||
| } = getSelectedDatasetsMode(datasets) | |||
| const { | |||
| @@ -233,12 +281,8 @@ export const checkoutRerankModelConfigedInRetrievalSettings = ( | |||
| reranking_model, | |||
| } = multipleRetrievalConfig | |||
| if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { | |||
| if ((allEconomic || allExternal) && !reranking_enable) | |||
| return true | |||
| return false | |||
| } | |||
| if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) | |||
| return ((allEconomic && allInternal) || allExternal) && !reranking_enable | |||
| return true | |||
| } | |||