| @@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({ | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel, | |||
| currentModel: currentRerankModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| @@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({ | |||
| : undefined, | |||
| ) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| }, [currentModel, rerankDefaultModel, t]) | |||
| const rerankModel = (() => { | |||
| if (datasetConfigs.reranking_model?.reranking_provider_name) { | |||
| return { | |||
| @@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({ | |||
| const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights | |||
| const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel | |||
| const canManuallyToggleRerank = useMemo(() => { | |||
| return !( | |||
| (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) | |||
| || selectedDatasetsMode.allExternal | |||
| ) | |||
| }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) | |||
| const showRerankModel = useMemo(() => { | |||
| if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic) | |||
| if (!canManuallyToggleRerank) | |||
| return false | |||
| return true | |||
| }, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic]) | |||
| return datasetConfigs.reranking_enable | |||
| }, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentRerankModel && !showRerankModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| }, [currentRerankModel, showRerankModel, t]) | |||
| useEffect(() => { | |||
| if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: showRerankModel, | |||
| }) | |||
| } | |||
| }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange]) | |||
| return ( | |||
| <div> | |||
| @@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({ | |||
| > | |||
| <Switch | |||
| size='md' | |||
| defaultValue={currentModel ? showRerankModel : false} | |||
| disabled={!currentModel} | |||
| defaultValue={showRerankModel} | |||
| disabled={!currentRerankModel || !canManuallyToggleRerank} | |||
| onChange={(v) => { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: v, | |||
| }) | |||
| if (canManuallyToggleRerank) { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: v, | |||
| }) | |||
| } | |||
| }} | |||
| /> | |||
| </div> | |||
| @@ -42,6 +42,7 @@ const ParamsConfig = ({ | |||
| allHighQuality, | |||
| allHighQualityFullTextSearch, | |||
| allHighQualityVectorSearch, | |||
| allInternal, | |||
| allExternal, | |||
| mixtureHighQualityAndEconomic, | |||
| inconsistentEmbeddingModel, | |||
| @@ -50,7 +51,7 @@ const ParamsConfig = ({ | |||
| const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs | |||
| let rerankEnable = restConfigs.reranking_enable | |||
| if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal) | |||
| if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) | |||
| rerankEnable = false | |||
| if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) | |||
| @@ -1,25 +1,17 @@ | |||
| import { useCallback } from 'react' | |||
| import { useStoreApi } from 'reactflow' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { useWorkflowStore } from '../store' | |||
| import { | |||
| BlockEnum, | |||
| WorkflowRunningStatus, | |||
| } from '../types' | |||
| import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' | |||
| import type { Node } from '../types' | |||
| import { useWorkflow } from './use-workflow' | |||
| import { | |||
| useIsChatMode, | |||
| useNodesSyncDraft, | |||
| useWorkflowInteractions, | |||
| useWorkflowRun, | |||
| } from './index' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useFeaturesStore } from '@/app/components/base/features/hooks' | |||
| import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' | |||
| import Toast from '@/app/components/base/toast' | |||
| export const useWorkflowStartRun = () => { | |||
| const store = useStoreApi() | |||
| @@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => { | |||
| const isChatMode = useIsChatMode() | |||
| const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() | |||
| const { handleRun } = useWorkflowRun() | |||
| const { isFromStartNode } = useWorkflow() | |||
| const { doSyncWorkflowDraft } = useNodesSyncDraft() | |||
| const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault | |||
| const { t } = useTranslation() | |||
| const { | |||
| modelList: rerankModelList, | |||
| defaultModel: rerankDefaultModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| ) | |||
| const handleWorkflowStartRunInWorkflow = useCallback(async () => { | |||
| const { | |||
| @@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => { | |||
| const { getNodes } = store.getState() | |||
| const nodes = getNodes() | |||
| const startNode = nodes.find(node => node.data.type === BlockEnum.Start) | |||
| const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) => | |||
| node.data.type === BlockEnum.KnowledgeRetrieval, | |||
| ) | |||
| const startVariables = startNode?.data.variables || [] | |||
| const fileSettings = featuresStore!.getState().features.file | |||
| const { | |||
| @@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => { | |||
| setShowEnvPanel, | |||
| } = workflowStore.getState() | |||
| if (knowledgeRetrievalNodes.length > 0) { | |||
| for (const node of knowledgeRetrievalNodes) { | |||
| if (isFromStartNode(node.id)) { | |||
| const res = checkKnowledgeRetrievalValid(node.data, t) | |||
| if (!res.isValid || !currentModel || !rerankDefaultModel) { | |||
| const errorMessage = res.errorMessage | |||
| if (errorMessage) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: errorMessage, | |||
| }) | |||
| return false | |||
| } | |||
| else { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: t('appDebug.datasetConfig.rerankModelRequired'), | |||
| }) | |||
| return false | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| setShowEnvPanel(false) | |||
| if (showDebugAndPreviewPanel) { | |||
| @@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets' | |||
| import { fetchDatasets } from '@/service/datasets' | |||
| import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' | |||
| import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| @@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| const startNodeId = startNode?.id | |||
| const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload) | |||
| const inputRef = useRef(inputs) | |||
| const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => { | |||
| const newInputs = produce(s, (draft) => { | |||
| if (s.retrieval_mode === RETRIEVE_TYPE.multiWay) | |||
| @@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| }) | |||
| // not work in pass to draft... | |||
| doSetInputs(newInputs) | |||
| inputRef.current = newInputs | |||
| }, [doSetInputs]) | |||
| const inputRef = useRef(inputs) | |||
| useEffect(() => { | |||
| inputRef.current = inputs | |||
| }, [inputs]) | |||
| const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.query_variable_selector = newVar as ValueSelector | |||
| @@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) | |||
| const { | |||
| modelList: rerankModelList, | |||
| defaultModel: rerankDefaultModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel: currentRerankModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| ) | |||
| const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { | |||
| const newInputs = produce(inputRef.current, (draft) => { | |||
| if (!draft.single_retrieval_config) { | |||
| @@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| // set defaults models | |||
| useEffect(() => { | |||
| const inputs = inputRef.current | |||
| if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider) | |||
| if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel) | |||
| return | |||
| if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider) | |||
| @@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } | |||
| } | |||
| } | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| draft.multiple_retrieval_config = { | |||
| top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k, | |||
| @@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| reranking_model: multipleRetrievalConfig?.reranking_model, | |||
| reranking_mode: multipleRetrievalConfig?.reranking_mode, | |||
| weights: multipleRetrievalConfig?.weights, | |||
| reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined | |||
| ? multipleRetrievalConfig.reranking_enable | |||
| : Boolean(currentRerankModel && rerankDefaultModel), | |||
| } | |||
| }) | |||
| setInputs(newInput) | |||
| @@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| }, []) | |||
| useEffect(() => { | |||
| const inputs = inputRef.current | |||
| let query_variable_selector: ValueSelector = inputs.query_variable_selector | |||
| if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId) | |||
| query_variable_selector = [startNodeId, 'sys.query'] | |||
| setInputs({ | |||
| ...inputs, | |||
| query_variable_selector, | |||
| }) | |||
| setInputs(produce(inputs, (draft) => { | |||
| draft.query_variable_selector = query_variable_selector | |||
| })) | |||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||
| }, []) | |||
| @@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr | |||
| reranking_mode, | |||
| reranking_model, | |||
| weights, | |||
| reranking_enable: allEconomic ? reranking_enable : true, | |||
| reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true, | |||
| } | |||
| if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) | |||