Pārlūkot izejas kodu

Fix/retrieval setting weight default value (#9622)

tags/0.10.1
zxhlyh pirms 1 gada
vecāks
revīzija
ff956cb546
Revīzijas autora e-pasta adrese nav piesaistīta nevienam kontam

+ 18
- 1
web/app/components/app/configuration/dataset-config/index.tsx Parādīt failu

import ConfigContext from '@/context/debug-configuration' import ConfigContext from '@/context/debug-configuration'
import { AppType } from '@/types/app' import { AppType } from '@/types/app'
import type { DataSet } from '@/models/datasets' import type { DataSet } from '@/models/datasets'
import {
getMultipleRetrievalConfig,
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'


const Icon = ( const Icon = (
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
setModelConfig, setModelConfig,
showSelectDataSet, showSelectDataSet,
isAgent, isAgent,
datasetConfigs,
setDatasetConfigs,
} = useContext(ConfigContext) } = useContext(ConfigContext)
const formattingChangedDispatcher = useFormattingChangedDispatcher() const formattingChangedDispatcher = useFormattingChangedDispatcher()


const hasData = dataSet.length > 0 const hasData = dataSet.length > 0


const {
currentModel: currentRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)

const onRemove = (id: string) => { const onRemove = (id: string) => {
setDataSet(dataSet.filter(item => item.id !== id))
const filteredDataSets = dataSet.filter(item => item.id !== id)
setDataSet(filteredDataSets)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
setDatasetConfigs({
...(datasetConfigs as any),
...retrievalConfig,
})
formattingChangedDispatcher() formattingChangedDispatcher()
} }



+ 1
- 1
web/app/components/app/configuration/dataset-config/params-config/config-content.tsx Parādīt failu

retrieval_model: RETRIEVE_TYPE.multiWay, retrieval_model: RETRIEVE_TYPE.multiWay,
}, isInWorkflow) }, isInWorkflow)
} }
}, [type])
}, [type, datasetConfigs, isInWorkflow, onChange])


const { const {
modelList: rerankModelList, modelList: rerankModelList,

+ 4
- 53
web/app/components/app/configuration/dataset-config/params-config/index.tsx Parādīt failu

import type { DatasetConfigs } from '@/models/debug' import type { DatasetConfigs } from '@/models/debug'
import { import {
getMultipleRetrievalConfig, getMultipleRetrievalConfig,
getSelectedDatasetsMode,
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils' } from '@/app/components/workflow/nodes/knowledge-retrieval/utils'


type ParamsConfigProps = { type ParamsConfigProps = {
const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs) const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs)


useEffect(() => { useEffect(() => {
const {
allEconomic,
allHighQuality,
allHighQualityFullTextSearch,
allHighQualityVectorSearch,
allExternal,
mixtureHighQualityAndEconomic,
inconsistentEmbeddingModel,
mixtureInternalAndExternal,
} = getSelectedDatasetsMode(selectedDatasets)

if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
setRerankSettingModalOpen(false)

if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1))
setRerankSettingModalOpen(true)
}, [selectedDatasets])

useEffect(() => {
const {
allEconomic,
allInternal,
allExternal,
} = getSelectedDatasetsMode(selectedDatasets)
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
let rerankEnable = restConfigs.reranking_enable

if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
rerankEnable = false

setTempDataSetConfigs({
...getMultipleRetrievalConfig({
top_k: restConfigs.top_k,
score_threshold: restConfigs.score_threshold,
reranking_model: restConfigs.reranking_model && {
provider: restConfigs.reranking_model.reranking_provider_name,
model: restConfigs.reranking_model.reranking_model_name,
},
reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights,
reranking_enable: rerankEnable,
}, selectedDatasets),
reranking_model: restConfigs.reranking_model && {
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
},
retrieval_model,
score_threshold_enabled,
datasets,
})
}, [selectedDatasets, datasetConfigs])
setTempDataSetConfigs(datasetConfigs)
}, [datasetConfigs])


const { const {
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
reranking_mode: restConfigs.reranking_mode, reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights, weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable, reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets)
}, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)


setTempDataSetConfigs({ setTempDataSetConfigs({
...retrievalConfig, ...retrievalConfig,


<div className='mt-6 flex justify-end'> <div className='mt-6 flex justify-end'>
<Button className='mr-2 flex-shrink-0' onClick={() => { <Button className='mr-2 flex-shrink-0' onClick={() => {
setTempDataSetConfigs(datasetConfigs)
setRerankSettingModalOpen(false) setRerankSettingModalOpen(false)
}}>{t('common.operation.cancel')}</Button> }}>{t('common.operation.cancel')}</Button>
<Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button> <Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>

+ 11
- 3
web/app/components/app/configuration/index.tsx Parādīt failu

import Config from '@/app/components/app/configuration/config' import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug' import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm' import Confirm from '@/app/components/base/confirm'
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps' import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config' import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
import Drawer from '@/app/components/base/drawer' import Drawer from '@/app/components/base/drawer'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations' import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import {
useModelListAndDefaultModelAndCurrentProviderAndModel,
useTextGenerationCurrentProviderAndModelAndModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import { fetchCollectionList } from '@/service/tools' import { fetchCollectionList } from '@/service/tools'
import { type Collection } from '@/app/components/tools/types' import { type Collection } from '@/app/components/tools/types'
import { useStore as useAppStore } from '@/app/components/app/store' import { useStore as useAppStore } from '@/app/components/app/store'
const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false) const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false)
const selectedIds = dataSets.map(item => item.id) const selectedIds = dataSets.map(item => item.id)
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
const {
currentModel: currentRerankModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const handleSelect = (data: DataSet[]) => { const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
hideSelectDataSet() hideSelectDataSet()
reranking_mode: restConfigs.reranking_mode, reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights, weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable, reranking_enable: restConfigs.reranking_enable,
}, newDatasets)
}, newDatasets, dataSets, !!currentRerankModel)


setDatasetConfigs({ setDatasetConfigs({
...retrievalConfig, ...retrievalConfig,


syncToPublishedConfig(config) syncToPublishedConfig(config)
setPublishedConfig(config) setPublishedConfig(config)
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
setDatasetConfigs({ setDatasetConfigs({
retrieval_model: RETRIEVE_TYPE.multiWay, retrieval_model: RETRIEVE_TYPE.multiWay,
...modelConfig.dataset_configs, ...modelConfig.dataset_configs,
...retrievalConfig,
}) })
setHasFetchedDetail(true) setHasFetchedDetail(true)
}) })

+ 6
- 6
web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts Parādīt failu

draft.retrieval_mode = newMode draft.retrieval_mode = newMode
if (newMode === RETRIEVE_TYPE.multiWay) { if (newMode === RETRIEVE_TYPE.multiWay) {
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
} }
else { else {
const hasSetModel = draft.single_retrieval_config?.model?.provider const hasSetModel = draft.single_retrieval_config?.model?.provider
} }
}) })
setInputs(newInputs) setInputs(newInputs)
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])


const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
}) })
setInputs(newInputs) setInputs(newInputs)
}, [inputs, setInputs, selectedDatasets])
}, [inputs, setInputs, selectedDatasets, currentRerankModel])


// datasets // datasets
useEffect(() => { useEffect(() => {


if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
} }
}) })
setInputs(newInputs) setInputs(newInputs)
|| (allExternal && newDatasets.length > 1) || (allExternal && newDatasets.length > 1)
) )
setRerankModelOpen(true) setRerankModelOpen(true)
}, [inputs, setInputs, payload.retrieval_mode])
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])


const filterVar = useCallback((varPayload: Var) => { const filterVar = useCallback((varPayload: Var) => {
return varPayload.type === VarType.string return varPayload.type === VarType.string

+ 46
- 3
web/app/components/workflow/nodes/knowledge-retrieval/utils.ts Parādīt failu

import { uniq } from 'lodash-es'
import {
uniq,
xorBy,
} from 'lodash-es'
import type { MultipleRetrievalConfig } from './types' import type { MultipleRetrievalConfig } from './types'
import type { import type {
DataSet, DataSet,
return true return true
} }


export const getSelectedDatasetsMode = (datasets: DataSet[]) => {
export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => {
if (datasets === null)
datasets = []
let allHighQuality = true let allHighQuality = true
let allHighQualityVectorSearch = true let allHighQualityVectorSearch = true
let allHighQualityFullTextSearch = true let allHighQualityFullTextSearch = true
} as SelectedDatasetsMode } as SelectedDatasetsMode
} }


export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => {
export const getMultipleRetrievalConfig = (
multipleRetrievalConfig: MultipleRetrievalConfig,
selectedDatasets: DataSet[],
originalDatasets: DataSet[],
isValidRerankModel?: boolean,
) => {
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0

const { const {
allHighQuality, allHighQuality,
allHighQualityVectorSearch, allHighQualityVectorSearch,
result.reranking_mode = RerankingModeEnum.WeightedScore result.reranking_mode = RerankingModeEnum.WeightedScore


if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
if (!isValidRerankModel)
result.reranking_mode = RerankingModeEnum.WeightedScore
else
result.reranking_mode = RerankingModeEnum.RerankingModel

result.weights = {
vector_setting: {
vector_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
: allHighQualityFullTextSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
: DEFAULT_WEIGHTED_SCORE.other.semantic,
embedding_provider_name: selectedDatasets[0].embedding_model_provider,
embedding_model_name: selectedDatasets[0].embedding_model,
},
keyword_setting: {
keyword_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
: allHighQualityFullTextSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
: DEFAULT_WEIGHTED_SCORE.other.keyword,
},
}
}

if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
if (!isValidRerankModel)
result.reranking_mode = RerankingModeEnum.WeightedScore
else
result.reranking_mode = RerankingModeEnum.RerankingModel

result.weights = { result.weights = {
vector_setting: { vector_setting: {
vector_weight: allHighQualityVectorSearch vector_weight: allHighQualityVectorSearch

+ 0
- 8
web/models/datasets.ts Parādīt failu

semantic: 0, semantic: 0,
keyword: 1.0, keyword: 1.0,
}, },
semanticFirst: {
semantic: 0.7,
keyword: 0.3,
},
keywordFirst: {
semantic: 0.3,
keyword: 0.7,
},
other: { other: {
semantic: 0.7, semantic: 0.7,
keyword: 0.3, keyword: 0.3,

Notiek ielāde…
Atcelt
Saglabāt