Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>tags/1.9.1
| const onRemove = (id: string) => { | const onRemove = (id: string) => { | ||||
| const filteredDataSets = dataSet.filter(item => item.id !== id) | const filteredDataSets = dataSet.filter(item => item.id !== id) | ||||
| setDataSet(filteredDataSets) | setDataSet(filteredDataSets) | ||||
| const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { | |||||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | |||||
| const { | |||||
| top_k, | |||||
| score_threshold, | |||||
| reranking_model, | |||||
| reranking_mode, | |||||
| weights, | |||||
| reranking_enable, | |||||
| } = restConfigs | |||||
| const oldRetrievalConfig = { | |||||
| top_k, | |||||
| score_threshold, | |||||
| reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { | |||||
| provider: reranking_model.reranking_provider_name, | |||||
| model: reranking_model.reranking_model_name, | |||||
| } : undefined, | |||||
| reranking_mode, | |||||
| weights, | |||||
| reranking_enable, | |||||
| } | |||||
| const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, filteredDataSets, dataSet, { | |||||
| provider: currentRerankProvider?.provider, | provider: currentRerankProvider?.provider, | ||||
| model: currentRerankModel?.model, | model: currentRerankModel?.model, | ||||
| }) | }) | ||||
| setDatasetConfigs({ | setDatasetConfigs({ | ||||
| ...(datasetConfigs as any), | |||||
| ...datasetConfigsRef.current, | |||||
| ...retrievalConfig, | ...retrievalConfig, | ||||
| reranking_model: { | |||||
| reranking_provider_name: retrievalConfig?.reranking_model?.provider || '', | |||||
| reranking_model_name: retrievalConfig?.reranking_model?.model || '', | |||||
| }, | |||||
| retrieval_model, | |||||
| score_threshold_enabled, | |||||
| datasets, | |||||
| }) | }) | ||||
| const { | const { | ||||
| allExternal, | allExternal, |
| type Props = { | type Props = { | ||||
| datasetConfigs: DatasetConfigs | datasetConfigs: DatasetConfigs | ||||
| onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void | onChange: (configs: DatasetConfigs, isRetrievalModeChange?: boolean) => void | ||||
| selectedDatasets?: DataSet[] | |||||
| isInWorkflow?: boolean | isInWorkflow?: boolean | ||||
| singleRetrievalModelConfig?: ModelConfig | singleRetrievalModelConfig?: ModelConfig | ||||
| onSingleRetrievalModelChange?: (config: ModelConfig) => void | onSingleRetrievalModelChange?: (config: ModelConfig) => void | ||||
| onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void | onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void | ||||
| selectedDatasets?: DataSet[] | |||||
| } | } | ||||
| const ConfigContent: FC<Props> = ({ | const ConfigContent: FC<Props> = ({ | ||||
| const { | const { | ||||
| modelList: rerankModelList, | modelList: rerankModelList, | ||||
| currentModel: validDefaultRerankModel, | |||||
| currentProvider: validDefaultRerankProvider, | |||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | ||||
| /** | |||||
| * If reranking model is set and is valid, use the reranking model | |||||
| * Otherwise, check if the default reranking model is valid | |||||
| */ | |||||
| const { | const { | ||||
| currentModel: currentRerankModel, | currentModel: currentRerankModel, | ||||
| } = useCurrentProviderAndModel( | } = useCurrentProviderAndModel( | ||||
| rerankModelList, | rerankModelList, | ||||
| { | { | ||||
| provider: datasetConfigs.reranking_model?.reranking_provider_name, | |||||
| model: datasetConfigs.reranking_model?.reranking_model_name, | |||||
| provider: datasetConfigs.reranking_model?.reranking_provider_name || validDefaultRerankProvider?.provider || '', | |||||
| model: datasetConfigs.reranking_model?.reranking_model_name || validDefaultRerankModel?.model || '', | |||||
| }, | }, | ||||
| ) | ) | ||||
| const rerankModel = useMemo(() => { | const rerankModel = useMemo(() => { | ||||
| return { | return { | ||||
| provider_name: datasetConfigs?.reranking_model?.reranking_provider_name ?? '', | |||||
| model_name: datasetConfigs?.reranking_model?.reranking_model_name ?? '', | |||||
| provider_name: datasetConfigs.reranking_model?.reranking_provider_name ?? '', | |||||
| model_name: datasetConfigs.reranking_model?.reranking_model_name ?? '', | |||||
| } | } | ||||
| }, [datasetConfigs.reranking_model]) | }, [datasetConfigs.reranking_model]) | ||||
| }) | }) | ||||
| } | } | ||||
| const model = singleRetrievalConfig | |||||
| const model = singleRetrievalConfig // Legacy code, for compatibility, have to keep it | |||||
| const rerankingModeOptions = [ | const rerankingModeOptions = [ | ||||
| { | { | ||||
| const canManuallyToggleRerank = useMemo(() => { | const canManuallyToggleRerank = useMemo(() => { | ||||
| return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | return (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | ||||
| || selectedDatasetsMode.allExternal | |||||
| || selectedDatasetsMode.allExternal | |||||
| }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | ||||
| const showRerankModel = useMemo(() => { | const showRerankModel = useMemo(() => { | ||||
| return datasetConfigs.reranking_enable | return datasetConfigs.reranking_enable | ||||
| }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) | }, [datasetConfigs.reranking_enable, canManuallyToggleRerank]) | ||||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||||
| const handleManuallyToggleRerank = useCallback((enable: boolean) => { | |||||
| if (!currentRerankModel && enable) | if (!currentRerankModel && enable) | ||||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | ||||
| onChange({ | onChange({ | ||||
| <div className='mt-2'> | <div className='mt-2'> | ||||
| <div className='flex items-center'> | <div className='flex items-center'> | ||||
| { | { | ||||
| selectedDatasetsMode.allEconomic && !selectedDatasetsMode.mixtureInternalAndExternal && ( | |||||
| canManuallyToggleRerank && ( | |||||
| <Switch | <Switch | ||||
| size='md' | size='md' | ||||
| defaultValue={showRerankModel} | defaultValue={showRerankModel} | ||||
| disabled={!canManuallyToggleRerank} | |||||
| onChange={handleDisabledSwitchClick} | |||||
| onChange={handleManuallyToggleRerank} | |||||
| /> | /> | ||||
| ) | ) | ||||
| } | } |
| setRerankSettingModalOpen(true) | setRerankSettingModalOpen(true) | ||||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | ||||
| const { | |||||
| top_k, | |||||
| score_threshold, | |||||
| reranking_model, | |||||
| reranking_mode, | |||||
| weights, | |||||
| reranking_enable, | |||||
| } = restConfigs | |||||
| const oldRetrievalConfig = { | |||||
| top_k, | |||||
| score_threshold, | |||||
| reranking_model: (reranking_model.reranking_provider_name && reranking_model.reranking_model_name) ? { | |||||
| provider: reranking_model.reranking_provider_name, | |||||
| model: reranking_model.reranking_model_name, | |||||
| } : undefined, | |||||
| reranking_mode, | |||||
| weights, | |||||
| reranking_enable, | |||||
| } | |||||
| const retrievalConfig = getMultipleRetrievalConfig({ | |||||
| top_k: restConfigs.top_k, | |||||
| score_threshold: restConfigs.score_threshold, | |||||
| reranking_model: restConfigs.reranking_model && { | |||||
| provider: restConfigs.reranking_model.reranking_provider_name, | |||||
| model: restConfigs.reranking_model.reranking_model_name, | |||||
| }, | |||||
| reranking_mode: restConfigs.reranking_mode, | |||||
| weights: restConfigs.weights, | |||||
| reranking_enable: restConfigs.reranking_enable, | |||||
| }, newDatasets, dataSets, { | |||||
| const retrievalConfig = getMultipleRetrievalConfig(oldRetrievalConfig, newDatasets, dataSets, { | |||||
| provider: currentRerankProvider?.provider, | provider: currentRerankProvider?.provider, | ||||
| model: currentRerankModel?.model, | model: currentRerankModel?.model, | ||||
| }) | }) |
| onChange({ | onChange({ | ||||
| ...value, | ...value, | ||||
| search_method: retrieveMethod, | search_method: retrieveMethod, | ||||
| ...(!value.reranking_model.reranking_model_name | |||||
| ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) | |||||
| ? { | ? { | ||||
| reranking_model: { | reranking_model: { | ||||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | ||||
| onChange({ | onChange({ | ||||
| ...value, | ...value, | ||||
| search_method: retrieveMethod, | search_method: retrieveMethod, | ||||
| ...(!value.reranking_model.reranking_model_name | |||||
| ...((!value.reranking_model.reranking_model_name || !value.reranking_model.reranking_provider_name) | |||||
| ? { | ? { | ||||
| reranking_model: { | reranking_model: { | ||||
| reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', | reranking_provider_name: isRerankDefaultModelValid ? rerankDefaultModel?.provider?.provider ?? '' : '', |
| }, | }, | ||||
| ) | ) | ||||
| const handleDisabledSwitchClick = useCallback((enable: boolean) => { | |||||
| const handleToggleRerankEnable = useCallback((enable: boolean) => { | |||||
| if (enable && !currentModel) | if (enable && !currentModel) | ||||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | ||||
| onChange({ | onChange({ | ||||
| <Switch | <Switch | ||||
| size='md' | size='md' | ||||
| defaultValue={value.reranking_enable} | defaultValue={value.reranking_enable} | ||||
| onChange={handleDisabledSwitchClick} | |||||
| onChange={handleToggleRerankEnable} | |||||
| /> | /> | ||||
| )} | )} | ||||
| <div className='flex items-center'> | <div className='flex items-center'> |
| 'use client' | 'use client' | ||||
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import React, { useCallback, useState } from 'react' | |||||
| import React, { useCallback, useMemo } from 'react' | |||||
| import { RiEqualizer2Line } from '@remixicon/react' | import { RiEqualizer2Line } from '@remixicon/react' | ||||
| import { useTranslation } from 'react-i18next' | import { useTranslation } from 'react-i18next' | ||||
| import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' | import type { MultipleRetrievalConfig, SingleRetrievalConfig } from '../types' | ||||
| import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content' | import ConfigRetrievalContent from '@/app/components/app/configuration/dataset-config/params-config/config-content' | ||||
| import { RETRIEVE_TYPE } from '@/types/app' | import { RETRIEVE_TYPE } from '@/types/app' | ||||
| import { DATASET_DEFAULT } from '@/config' | import { DATASET_DEFAULT } from '@/config' | ||||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||||
| import Button from '@/app/components/base/button' | import Button from '@/app/components/base/button' | ||||
| import type { DatasetConfigs } from '@/models/debug' | import type { DatasetConfigs } from '@/models/debug' | ||||
| import type { DataSet } from '@/models/datasets' | import type { DataSet } from '@/models/datasets' | ||||
| onSingleRetrievalModelChange?: (config: ModelConfig) => void | onSingleRetrievalModelChange?: (config: ModelConfig) => void | ||||
| onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void | onSingleRetrievalModelParamsChange?: (config: ModelConfig) => void | ||||
| readonly?: boolean | readonly?: boolean | ||||
| openFromProps?: boolean | |||||
| onOpenFromPropsChange?: (openFromProps: boolean) => void | |||||
| rerankModalOpen: boolean | |||||
| onRerankModelOpenChange: (open: boolean) => void | |||||
| selectedDatasets: DataSet[] | selectedDatasets: DataSet[] | ||||
| } | } | ||||
| onSingleRetrievalModelChange, | onSingleRetrievalModelChange, | ||||
| onSingleRetrievalModelParamsChange, | onSingleRetrievalModelParamsChange, | ||||
| readonly, | readonly, | ||||
| openFromProps, | |||||
| onOpenFromPropsChange, | |||||
| rerankModalOpen, | |||||
| onRerankModelOpenChange, | |||||
| selectedDatasets, | selectedDatasets, | ||||
| }) => { | }) => { | ||||
| const { t } = useTranslation() | const { t } = useTranslation() | ||||
| const [open, setOpen] = useState(false) | |||||
| const mergedOpen = openFromProps !== undefined ? openFromProps : open | |||||
| const { retrieval_mode, multiple_retrieval_config } = payload | |||||
| const handleOpen = useCallback((newOpen: boolean) => { | const handleOpen = useCallback((newOpen: boolean) => { | ||||
| setOpen(newOpen) | |||||
| onOpenFromPropsChange?.(newOpen) | |||||
| }, [onOpenFromPropsChange]) | |||||
| onRerankModelOpenChange(newOpen) | |||||
| }, [onRerankModelOpenChange]) | |||||
| const { | |||||
| currentProvider: validRerankDefaultProvider, | |||||
| currentModel: validRerankDefaultModel, | |||||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||||
| const datasetConfigs = useMemo(() => { | |||||
| const { | |||||
| reranking_model, | |||||
| top_k, | |||||
| score_threshold, | |||||
| reranking_mode, | |||||
| weights, | |||||
| reranking_enable, | |||||
| } = multiple_retrieval_config || {} | |||||
| return { | |||||
| retrieval_model: retrieval_mode, | |||||
| reranking_model: (reranking_model?.provider && reranking_model?.model) | |||||
| ? { | |||||
| reranking_provider_name: reranking_model?.provider, | |||||
| reranking_model_name: reranking_model?.model, | |||||
| } | |||||
| : { | |||||
| reranking_provider_name: '', | |||||
| reranking_model_name: '', | |||||
| }, | |||||
| top_k: top_k || DATASET_DEFAULT.top_k, | |||||
| score_threshold_enabled: !(score_threshold === undefined || score_threshold === null), | |||||
| score_threshold, | |||||
| datasets: { | |||||
| datasets: [], | |||||
| }, | |||||
| reranking_mode, | |||||
| weights, | |||||
| reranking_enable, | |||||
| } | |||||
| }, [retrieval_mode, multiple_retrieval_config]) | |||||
| const { multiple_retrieval_config } = payload | |||||
| const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { | const handleChange = useCallback((configs: DatasetConfigs, isRetrievalModeChange?: boolean) => { | ||||
| // Legacy code, for compatibility, have to keep it | |||||
| if (isRetrievalModeChange) { | if (isRetrievalModeChange) { | ||||
| onRetrievalModeChange(configs.retrieval_model) | onRetrievalModeChange(configs.retrieval_model) | ||||
| return | return | ||||
| onMultipleRetrievalConfigChange({ | onMultipleRetrievalConfigChange({ | ||||
| top_k: configs.top_k, | top_k: configs.top_k, | ||||
| score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, | score_threshold: configs.score_threshold_enabled ? (configs.score_threshold ?? DATASET_DEFAULT.score_threshold) : null, | ||||
| reranking_model: payload.retrieval_mode === RETRIEVE_TYPE.oneWay | |||||
| reranking_model: retrieval_mode === RETRIEVE_TYPE.oneWay | |||||
| ? undefined | ? undefined | ||||
| // eslint-disable-next-line sonarjs/no-nested-conditional | |||||
| : (!configs.reranking_model?.reranking_provider_name | : (!configs.reranking_model?.reranking_provider_name | ||||
| ? { | |||||
| provider: validRerankDefaultProvider?.provider || '', | |||||
| model: validRerankDefaultModel?.model || '', | |||||
| } | |||||
| ? undefined | |||||
| : { | : { | ||||
| provider: configs.reranking_model?.reranking_provider_name, | provider: configs.reranking_model?.reranking_provider_name, | ||||
| model: configs.reranking_model?.reranking_model_name, | model: configs.reranking_model?.reranking_model_name, | ||||
| weights: configs.weights, | weights: configs.weights, | ||||
| reranking_enable: configs.reranking_enable, | reranking_enable: configs.reranking_enable, | ||||
| }) | }) | ||||
| }, [onMultipleRetrievalConfigChange, payload.retrieval_mode, validRerankDefaultProvider, validRerankDefaultModel, onRetrievalModeChange]) | |||||
| }, [onMultipleRetrievalConfigChange, retrieval_mode, onRetrievalModeChange]) | |||||
| return ( | return ( | ||||
| <PortalToFollowElem | <PortalToFollowElem | ||||
| open={mergedOpen} | |||||
| open={rerankModalOpen} | |||||
| onOpenChange={handleOpen} | onOpenChange={handleOpen} | ||||
| placement='bottom-end' | placement='bottom-end' | ||||
| offset={{ | offset={{ | ||||
| onClick={() => { | onClick={() => { | ||||
| if (readonly) | if (readonly) | ||||
| return | return | ||||
| handleOpen(!mergedOpen) | |||||
| handleOpen(!rerankModalOpen) | |||||
| }} | }} | ||||
| > | > | ||||
| <Button | <Button | ||||
| variant='ghost' | variant='ghost' | ||||
| size='small' | size='small' | ||||
| disabled={readonly} | disabled={readonly} | ||||
| className={cn(open && 'bg-components-button-ghost-bg-hover')} | |||||
| className={cn(rerankModalOpen && 'bg-components-button-ghost-bg-hover')} | |||||
| > | > | ||||
| <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | <RiEqualizer2Line className='mr-1 h-3.5 w-3.5' /> | ||||
| {t('dataset.retrievalSettings')} | {t('dataset.retrievalSettings')} | ||||
| <PortalToFollowElemContent style={{ zIndex: 1001 }}> | <PortalToFollowElemContent style={{ zIndex: 1001 }}> | ||||
| <div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'> | <div className='w-[404px] rounded-2xl border border-components-panel-border bg-components-panel-bg px-4 pb-4 pt-3 shadow-xl'> | ||||
| <ConfigRetrievalContent | <ConfigRetrievalContent | ||||
| datasetConfigs={ | |||||
| { | |||||
| retrieval_model: payload.retrieval_mode, | |||||
| reranking_model: multiple_retrieval_config?.reranking_model?.provider | |||||
| ? { | |||||
| reranking_provider_name: multiple_retrieval_config.reranking_model?.provider, | |||||
| reranking_model_name: multiple_retrieval_config.reranking_model?.model, | |||||
| } | |||||
| : { | |||||
| reranking_provider_name: '', | |||||
| reranking_model_name: '', | |||||
| }, | |||||
| top_k: multiple_retrieval_config?.top_k || DATASET_DEFAULT.top_k, | |||||
| score_threshold_enabled: !(multiple_retrieval_config?.score_threshold === undefined || multiple_retrieval_config.score_threshold === null), | |||||
| score_threshold: multiple_retrieval_config?.score_threshold, | |||||
| datasets: { | |||||
| datasets: [], | |||||
| }, | |||||
| reranking_mode: multiple_retrieval_config?.reranking_mode, | |||||
| weights: multiple_retrieval_config?.weights, | |||||
| reranking_enable: multiple_retrieval_config?.reranking_enable, | |||||
| } | |||||
| } | |||||
| datasetConfigs={datasetConfigs} | |||||
| onChange={handleChange} | onChange={handleChange} | ||||
| selectedDatasets={selectedDatasets} | |||||
| isInWorkflow | isInWorkflow | ||||
| singleRetrievalModelConfig={singleRetrievalModelConfig} | singleRetrievalModelConfig={singleRetrievalModelConfig} | ||||
| onSingleRetrievalModelChange={onSingleRetrievalModelChange} | onSingleRetrievalModelChange={onSingleRetrievalModelChange} | ||||
| onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange} | onSingleRetrievalModelParamsChange={onSingleRetrievalModelParamsChange} | ||||
| selectedDatasets={selectedDatasets} | |||||
| /> | /> | ||||
| </div> | </div> | ||||
| </PortalToFollowElemContent> | </PortalToFollowElemContent> |
| import type { NodeDefault } from '../../types' | import type { NodeDefault } from '../../types' | ||||
| import type { KnowledgeRetrievalNodeType } from './types' | import type { KnowledgeRetrievalNodeType } from './types' | ||||
| import { checkoutRerankModelConfigedInRetrievalSettings } from './utils' | |||||
| import { checkoutRerankModelConfiguredInRetrievalSettings } from './utils' | |||||
| import { DATASET_DEFAULT } from '@/config' | import { DATASET_DEFAULT } from '@/config' | ||||
| import { RETRIEVE_TYPE } from '@/types/app' | import { RETRIEVE_TYPE } from '@/types/app' | ||||
| import { genNodeMetaData } from '@/app/components/workflow/utils' | import { genNodeMetaData } from '@/app/components/workflow/utils' | ||||
| const { _datasets, multiple_retrieval_config, retrieval_mode } = payload | const { _datasets, multiple_retrieval_config, retrieval_mode } = payload | ||||
| if (retrieval_mode === RETRIEVE_TYPE.multiWay) { | if (retrieval_mode === RETRIEVE_TYPE.multiWay) { | ||||
| const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config) | |||||
| const checked = checkoutRerankModelConfiguredInRetrievalSettings(_datasets || [], multiple_retrieval_config) | |||||
| if (!errorMessages && !checked) | if (!errorMessages && !checked) | ||||
| errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) | errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) |
| import type { FC } from 'react' | import type { FC } from 'react' | ||||
| import { | import { | ||||
| memo, | memo, | ||||
| useCallback, | |||||
| useMemo, | useMemo, | ||||
| } from 'react' | } from 'react' | ||||
| import { intersectionBy } from 'lodash-es' | import { intersectionBy } from 'lodash-es' | ||||
| availableNumberNodesWithParent, | availableNumberNodesWithParent, | ||||
| } = useConfig(id, data) | } = useConfig(id, data) | ||||
| const handleOpenFromPropsChange = useCallback((openFromProps: boolean) => { | |||||
| setRerankModelOpen(openFromProps) | |||||
| }, [setRerankModelOpen]) | |||||
| const metadataList = useMemo(() => { | const metadataList = useMemo(() => { | ||||
| return intersectionBy(...selectedDatasets.filter((dataset) => { | return intersectionBy(...selectedDatasets.filter((dataset) => { | ||||
| return !!dataset.doc_metadata | return !!dataset.doc_metadata | ||||
| return ( | return ( | ||||
| <div className='pt-2'> | <div className='pt-2'> | ||||
| <div className='space-y-4 px-4 pb-2'> | <div className='space-y-4 px-4 pb-2'> | ||||
| {/* {JSON.stringify(inputs, null, 2)} */} | |||||
| <Field | <Field | ||||
| title={t(`${i18nPrefix}.queryVariable`)} | title={t(`${i18nPrefix}.queryVariable`)} | ||||
| required | required | ||||
| onSingleRetrievalModelChange={handleModelChanged as any} | onSingleRetrievalModelChange={handleModelChanged as any} | ||||
| onSingleRetrievalModelParamsChange={handleCompletionParamsChange} | onSingleRetrievalModelParamsChange={handleCompletionParamsChange} | ||||
| readonly={readOnly || !selectedDatasets.length} | readonly={readOnly || !selectedDatasets.length} | ||||
| openFromProps={rerankModelOpen} | |||||
| onOpenFromPropsChange={handleOpenFromPropsChange} | |||||
| rerankModalOpen={rerankModelOpen} | |||||
| onRerankModelOpenChange={setRerankModelOpen} | |||||
| selectedDatasets={selectedDatasets} | selectedDatasets={selectedDatasets} | ||||
| /> | /> | ||||
| {!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)} | {!readOnly && (<div className='h-3 w-px bg-divider-regular'></div>)} |
| const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { | ||||
| const newInputs = produce(inputs, (draft) => { | const newInputs = produce(inputs, (draft) => { | ||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { | |||||
| const newMultipleRetrievalConfig = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { | |||||
| provider: currentRerankProvider?.provider, | provider: currentRerankProvider?.provider, | ||||
| model: currentRerankModel?.model, | model: currentRerankModel?.model, | ||||
| }) | }) | ||||
| draft.multiple_retrieval_config = newMultipleRetrievalConfig | |||||
| }) | }) | ||||
| setInputs(newInputs) | setInputs(newInputs) | ||||
| }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) | }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) | ||||
| if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | ||||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | const multipleRetrievalConfig = draft.multiple_retrieval_config | ||||
| draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { | |||||
| const newMultipleRetrievalConfig = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { | |||||
| provider: currentRerankProvider?.provider, | provider: currentRerankProvider?.provider, | ||||
| model: currentRerankModel?.model, | model: currentRerankModel?.model, | ||||
| }) | }) | ||||
| draft.multiple_retrieval_config = newMultipleRetrievalConfig | |||||
| } | } | ||||
| }) | }) | ||||
| updateDatasetsDetail(newDatasets) | updateDatasetsDetail(newDatasets) |
| import { | import { | ||||
| DEFAULT_WEIGHTED_SCORE, | DEFAULT_WEIGHTED_SCORE, | ||||
| RerankingModeEnum, | RerankingModeEnum, | ||||
| WeightedScoreEnum, | |||||
| } from '@/models/datasets' | } from '@/models/datasets' | ||||
| import { RETRIEVE_METHOD } from '@/types/app' | import { RETRIEVE_METHOD } from '@/types/app' | ||||
| import { DATASET_DEFAULT } from '@/config' | import { DATASET_DEFAULT } from '@/config' | ||||
| multipleRetrievalConfig: MultipleRetrievalConfig, | multipleRetrievalConfig: MultipleRetrievalConfig, | ||||
| selectedDatasets: DataSet[], | selectedDatasets: DataSet[], | ||||
| originalDatasets: DataSet[], | originalDatasets: DataSet[], | ||||
| validRerankModel?: { provider?: string; model?: string }, | |||||
| fallbackRerankModel?: { provider?: string; model?: string }, // fallback rerank model | |||||
| ) => { | ) => { | ||||
| const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||||
| const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model | |||||
| // Check if the selected datasets are different from the original datasets | |||||
| const isDatasetsChanged = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 | |||||
| // Check if the rerank model is valid | |||||
| const isFallbackRerankModelValid = !!(fallbackRerankModel?.provider && fallbackRerankModel?.model) | |||||
| const { | const { | ||||
| allHighQuality, | allHighQuality, | ||||
| reranking_mode, | reranking_mode, | ||||
| reranking_model, | reranking_model, | ||||
| weights, | weights, | ||||
| reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : shouldSetWeightDefaultValue, | |||||
| reranking_enable, | |||||
| } | } | ||||
| const setDefaultWeights = () => { | const setDefaultWeights = () => { | ||||
| result.weights = { | result.weights = { | ||||
| weight_type: WeightedScoreEnum.Customized, | |||||
| vector_setting: { | vector_setting: { | ||||
| vector_weight: allHighQualityVectorSearch | vector_weight: allHighQualityVectorSearch | ||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic | ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic | ||||
| // eslint-disable-next-line sonarjs/no-nested-conditional | |||||
| : allHighQualityFullTextSearch | : allHighQualityFullTextSearch | ||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic | ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic | ||||
| : DEFAULT_WEIGHTED_SCORE.other.semantic, | : DEFAULT_WEIGHTED_SCORE.other.semantic, | ||||
| keyword_setting: { | keyword_setting: { | ||||
| keyword_weight: allHighQualityVectorSearch | keyword_weight: allHighQualityVectorSearch | ||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword | ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword | ||||
| // eslint-disable-next-line sonarjs/no-nested-conditional | |||||
| : allHighQualityFullTextSearch | : allHighQualityFullTextSearch | ||||
| ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword | ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword | ||||
| : DEFAULT_WEIGHTED_SCORE.other.keyword, | : DEFAULT_WEIGHTED_SCORE.other.keyword, | ||||
| } | } | ||||
| } | } | ||||
| if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) { | |||||
| /** | |||||
| * In this case, user can manually toggle reranking | |||||
| * So should keep the reranking_enable value | |||||
| * But the default reranking_model should be set | |||||
| */ | |||||
| if ((allEconomic && allInternal) || allExternal) { | |||||
| result.reranking_mode = RerankingModeEnum.RerankingModel | result.reranking_mode = RerankingModeEnum.RerankingModel | ||||
| if (!result.reranking_model?.provider || !result.reranking_model?.model) { | |||||
| if (rerankModelIsValid) { | |||||
| result.reranking_enable = reranking_enable !== false | |||||
| result.reranking_model = { | |||||
| provider: validRerankModel?.provider || '', | |||||
| model: validRerankModel?.model || '', | |||||
| } | |||||
| } | |||||
| else { | |||||
| result.reranking_model = { | |||||
| provider: '', | |||||
| model: '', | |||||
| } | |||||
| // Need to check if the reranking model should be set to default when first time initialized | |||||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||||
| result.reranking_model = { | |||||
| provider: fallbackRerankModel.provider || '', | |||||
| model: fallbackRerankModel.model || '', | |||||
| } | } | ||||
| } | } | ||||
| else { | |||||
| result.reranking_enable = reranking_enable !== false | |||||
| result.reranking_enable = reranking_enable | |||||
| } | |||||
| /** | |||||
| * In this case, reranking_enable must be true | |||||
| * And if rerank model is not set, should set the default rerank model | |||||
| */ | |||||
| if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal) { | |||||
| result.reranking_mode = RerankingModeEnum.RerankingModel | |||||
| // Need to check if the reranking model should be set to default when first time initialized | |||||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||||
| result.reranking_model = { | |||||
| provider: fallbackRerankModel.provider || '', | |||||
| model: fallbackRerankModel.model || '', | |||||
| } | |||||
| } | } | ||||
| result.reranking_enable = true | |||||
| } | } | ||||
| /** | |||||
| * In this case, user can choose to use weighted score or rerank model | |||||
| * But if the reranking_mode is not initialized, should set the default rerank model and reranking_enable to true | |||||
| * and set reranking_mode to reranking_model | |||||
| */ | |||||
| if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { | if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { | ||||
| // If not initialized, check if the default rerank model is valid | |||||
| if (!reranking_mode) { | if (!reranking_mode) { | ||||
| if (validRerankModel?.provider && validRerankModel?.model) { | |||||
| if (isFallbackRerankModelValid) { | |||||
| result.reranking_mode = RerankingModeEnum.RerankingModel | result.reranking_mode = RerankingModeEnum.RerankingModel | ||||
| result.reranking_enable = reranking_enable !== false | |||||
| result.reranking_enable = true | |||||
| result.reranking_model = { | result.reranking_model = { | ||||
| provider: validRerankModel.provider, | |||||
| model: validRerankModel.model, | |||||
| provider: fallbackRerankModel.provider || '', | |||||
| model: fallbackRerankModel.model || '', | |||||
| } | } | ||||
| } | } | ||||
| else { | else { | ||||
| result.reranking_mode = RerankingModeEnum.WeightedScore | result.reranking_mode = RerankingModeEnum.WeightedScore | ||||
| result.reranking_enable = false | |||||
| setDefaultWeights() | setDefaultWeights() | ||||
| } | } | ||||
| } | } | ||||
| if (reranking_mode === RerankingModeEnum.WeightedScore && !weights) | |||||
| setDefaultWeights() | |||||
| // After initialization, if datasets has no change, make sure the config has correct value | |||||
| if (reranking_mode === RerankingModeEnum.WeightedScore) { | |||||
| result.reranking_enable = false | |||||
| if (!weights) | |||||
| setDefaultWeights() | |||||
| } | |||||
| if (reranking_mode === RerankingModeEnum.RerankingModel) { | |||||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||||
| result.reranking_model = { | |||||
| provider: fallbackRerankModel.provider || '', | |||||
| model: fallbackRerankModel.model || '', | |||||
| } | |||||
| } | |||||
| result.reranking_enable = true | |||||
| } | |||||
| if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) { | |||||
| if (rerankModelIsValid) { | |||||
| // Need to check if reranking_mode should be set to reranking_model when datasets changed | |||||
| if (reranking_mode === RerankingModeEnum.WeightedScore && weights && isDatasetsChanged) { | |||||
| if ((result.reranking_model?.provider && result.reranking_model?.model) || isFallbackRerankModelValid) { | |||||
| result.reranking_mode = RerankingModeEnum.RerankingModel | result.reranking_mode = RerankingModeEnum.RerankingModel | ||||
| result.reranking_enable = reranking_enable !== false | |||||
| result.reranking_enable = true | |||||
| result.reranking_model = { | |||||
| provider: validRerankModel.provider || '', | |||||
| model: validRerankModel.model || '', | |||||
| // eslint-disable-next-line sonarjs/nested-control-flow | |||||
| if ((!result.reranking_model?.provider || !result.reranking_model?.model) && isFallbackRerankModelValid) { | |||||
| result.reranking_model = { | |||||
| provider: fallbackRerankModel.provider || '', | |||||
| model: fallbackRerankModel.model || '', | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| else { | else { | ||||
| setDefaultWeights() | setDefaultWeights() | ||||
| } | } | ||||
| } | } | ||||
| if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { | |||||
| // Need to switch to weighted score when reranking model is not valid and datasets changed | |||||
| if ( | |||||
| reranking_mode === RerankingModeEnum.RerankingModel | |||||
| && (!result.reranking_model?.provider || !result.reranking_model?.model) | |||||
| && !isFallbackRerankModelValid | |||||
| && isDatasetsChanged | |||||
| ) { | |||||
| result.reranking_mode = RerankingModeEnum.WeightedScore | result.reranking_mode = RerankingModeEnum.WeightedScore | ||||
| result.reranking_enable = false | |||||
| setDefaultWeights() | setDefaultWeights() | ||||
| } | } | ||||
| } | } | ||||
| return result | return result | ||||
| } | } | ||||
| export const checkoutRerankModelConfigedInRetrievalSettings = ( | |||||
| export const checkoutRerankModelConfiguredInRetrievalSettings = ( | |||||
| datasets: DataSet[], | datasets: DataSet[], | ||||
| multipleRetrievalConfig?: MultipleRetrievalConfig, | multipleRetrievalConfig?: MultipleRetrievalConfig, | ||||
| ) => { | ) => { | ||||
| const { | const { | ||||
| allEconomic, | allEconomic, | ||||
| allExternal, | allExternal, | ||||
| allInternal, | |||||
| } = getSelectedDatasetsMode(datasets) | } = getSelectedDatasetsMode(datasets) | ||||
| const { | const { | ||||
| reranking_model, | reranking_model, | ||||
| } = multipleRetrievalConfig | } = multipleRetrievalConfig | ||||
| if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { | |||||
| if ((allEconomic || allExternal) && !reranking_enable) | |||||
| return true | |||||
| return false | |||||
| } | |||||
| if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) | |||||
| return ((allEconomic && allInternal) || allExternal) && !reranking_enable | |||||
| return true | return true | ||||
| } | } |