| @@ -1,6 +1,6 @@ | |||
| 'use client' | |||
| import { memo, useEffect, useMemo } from 'react' | |||
| import { memo, useCallback, useEffect, useMemo } from 'react' | |||
| import type { FC } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import WeightedScore from './weighted-score' | |||
| @@ -11,7 +11,7 @@ import type { | |||
| DatasetConfigs, | |||
| } from '@/models/debug' | |||
| import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' | |||
| import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import type { ModelConfig } from '@/app/components/workflow/types' | |||
| import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| @@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets' | |||
| import cn from '@/utils/classnames' | |||
| import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' | |||
| import Switch from '@/app/components/base/switch' | |||
| import Toast from '@/app/components/base/toast' | |||
| type Props = { | |||
| datasetConfigs: DatasetConfigs | |||
| @@ -60,6 +61,24 @@ const ConfigContent: FC<Props> = ({ | |||
| modelList: rerankModelList, | |||
| defaultModel: rerankDefaultModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| ) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| }, [currentModel, rerankDefaultModel, t]) | |||
| const rerankModel = (() => { | |||
| if (datasetConfigs.reranking_model?.reranking_provider_name) { | |||
| return { | |||
| @@ -231,16 +250,22 @@ const ConfigContent: FC<Props> = ({ | |||
| <div className='flex items-center'> | |||
| { | |||
| selectedDatasetsMode.allEconomic && ( | |||
| <Switch | |||
| size='md' | |||
| defaultValue={showRerankModel} | |||
| onChange={(v) => { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: v, | |||
| }) | |||
| }} | |||
| /> | |||
| <div | |||
| className='flex items-center' | |||
| onClick={handleDisabledSwitchClick} | |||
| > | |||
| <Switch | |||
| size='md' | |||
| defaultValue={currentModel ? showRerankModel : false} | |||
| disabled={!currentModel} | |||
| onChange={(v) => { | |||
| onChange({ | |||
| ...datasetConfigs, | |||
| reranking_enable: v, | |||
| }) | |||
| }} | |||
| /> | |||
| </div> | |||
| ) | |||
| } | |||
| <div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div> | |||
| @@ -1,6 +1,6 @@ | |||
| 'use client' | |||
| import type { FC } from 'react' | |||
| import React from 'react' | |||
| import React, { useCallback } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import cn from '@/utils/classnames' | |||
| @@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch' | |||
| import Tooltip from '@/app/components/base/tooltip' | |||
| import type { RetrievalConfig } from '@/types/app' | |||
| import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' | |||
| import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { | |||
| DEFAULT_WEIGHTED_SCORE, | |||
| @@ -19,6 +19,7 @@ import { | |||
| WeightedScoreEnum, | |||
| } from '@/models/datasets' | |||
| import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' | |||
| import Toast from '@/app/components/base/toast' | |||
| type Props = { | |||
| type: RETRIEVE_METHOD | |||
| @@ -38,6 +39,24 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| defaultModel: rerankDefaultModel, | |||
| modelList: rerankModelList, | |||
| } = useModelListAndDefaultModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| ) | |||
| const handleDisabledSwitchClick = useCallback(() => { | |||
| if (!currentModel) | |||
| Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) | |||
| }, [currentModel, rerankDefaultModel, t]) | |||
| const isHybridSearch = type === RETRIEVE_METHOD.hybrid | |||
| const rerankModel = (() => { | |||
| @@ -99,16 +118,22 @@ const RetrievalParamConfig: FC<Props> = ({ | |||
| <div> | |||
| <div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'> | |||
| {canToggleRerankModalEnable && ( | |||
| <Switch | |||
| size='md' | |||
| defaultValue={value.reranking_enable} | |||
| onChange={(v) => { | |||
| onChange({ | |||
| ...value, | |||
| reranking_enable: v, | |||
| }) | |||
| }} | |||
| /> | |||
| <div | |||
| className='flex items-center' | |||
| onClick={handleDisabledSwitchClick} | |||
| > | |||
| <Switch | |||
| size='md' | |||
| defaultValue={currentModel ? value.reranking_enable : false} | |||
| onChange={(v) => { | |||
| onChange({ | |||
| ...value, | |||
| reranking_enable: v, | |||
| }) | |||
| }} | |||
| disabled={!currentModel} | |||
| /> | |||
| </div> | |||
| )} | |||
| <div className='flex items-center'> | |||
| <span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span> | |||
| @@ -1,17 +1,25 @@ | |||
| import { useCallback } from 'react' | |||
| import { useStoreApi } from 'reactflow' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { useWorkflowStore } from '../store' | |||
| import { | |||
| BlockEnum, | |||
| WorkflowRunningStatus, | |||
| } from '../types' | |||
| import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' | |||
| import type { Node } from '../types' | |||
| import { useWorkflow } from './use-workflow' | |||
| import { | |||
| useIsChatMode, | |||
| useNodesSyncDraft, | |||
| useWorkflowInteractions, | |||
| useWorkflowRun, | |||
| } from './index' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { useFeaturesStore } from '@/app/components/base/features/hooks' | |||
| import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' | |||
| import Toast from '@/app/components/base/toast' | |||
| export const useWorkflowStartRun = () => { | |||
| const store = useStoreApi() | |||
| @@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => { | |||
| const isChatMode = useIsChatMode() | |||
| const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() | |||
| const { handleRun } = useWorkflowRun() | |||
| const { isFromStartNode } = useWorkflow() | |||
| const { doSyncWorkflowDraft } = useNodesSyncDraft() | |||
| const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault | |||
| const { t } = useTranslation() | |||
| const { | |||
| modelList: rerankModelList, | |||
| defaultModel: rerankDefaultModel, | |||
| } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) | |||
| const { | |||
| currentModel, | |||
| } = useCurrentProviderAndModel( | |||
| rerankModelList, | |||
| rerankDefaultModel | |||
| ? { | |||
| ...rerankDefaultModel, | |||
| provider: rerankDefaultModel.provider.provider, | |||
| } | |||
| : undefined, | |||
| ) | |||
| const handleWorkflowStartRunInWorkflow = useCallback(async () => { | |||
| const { | |||
| @@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => { | |||
| const { getNodes } = store.getState() | |||
| const nodes = getNodes() | |||
| const startNode = nodes.find(node => node.data.type === BlockEnum.Start) | |||
| const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) => | |||
| node.data.type === BlockEnum.KnowledgeRetrieval, | |||
| ) | |||
| const startVariables = startNode?.data.variables || [] | |||
| const fileSettings = featuresStore!.getState().features.file | |||
| const { | |||
| @@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => { | |||
| setShowEnvPanel, | |||
| } = workflowStore.getState() | |||
| if (knowledgeRetrievalNodes.length > 0) { | |||
| for (const node of knowledgeRetrievalNodes) { | |||
| if (isFromStartNode(node.id)) { | |||
| const res = checkKnowledgeRetrievalValid(node.data, t) | |||
| if (!res.isValid || !currentModel || !rerankDefaultModel) { | |||
| const errorMessage = res.errorMessage | |||
| if (errorMessage) { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: errorMessage, | |||
| }) | |||
| return false | |||
| } | |||
| else { | |||
| Toast.notify({ | |||
| type: 'error', | |||
| message: t('appDebug.datasetConfig.rerankModelRequired'), | |||
| }) | |||
| return false | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| setShowEnvPanel(false) | |||
| if (showDebugAndPreviewPanel) { | |||
| @@ -235,6 +235,33 @@ export const useWorkflow = () => { | |||
| return nodes.filter(node => node.parentId === nodeId) | |||
| }, [store]) | |||
| const isFromStartNode = useCallback((nodeId: string) => { | |||
| const { getNodes } = store.getState() | |||
| const nodes = getNodes() | |||
| const currentNode = nodes.find(node => node.id === nodeId) | |||
| if (!currentNode) | |||
| return false | |||
| if (currentNode.data.type === BlockEnum.Start) | |||
| return true | |||
| const checkPreviousNodes = (node: Node) => { | |||
| const previousNodes = getBeforeNodeById(node.id) | |||
| for (const prevNode of previousNodes) { | |||
| if (prevNode.data.type === BlockEnum.Start) | |||
| return true | |||
| if (checkPreviousNodes(prevNode)) | |||
| return true | |||
| } | |||
| return false | |||
| } | |||
| return checkPreviousNodes(currentNode) | |||
| }, [store, getBeforeNodeById]) | |||
| const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => { | |||
| const { getNodes, setNodes } = store.getState() | |||
| const afterNodes = getAfterNodesInSameBranch(nodeId) | |||
| @@ -389,6 +416,7 @@ export const useWorkflow = () => { | |||
| checkParallelLimit, | |||
| checkNestedParallelLimit, | |||
| isValidConnection, | |||
| isFromStartNode, | |||
| formatTimeFromNow, | |||
| getNode, | |||
| getBeforeNodeById, | |||
| @@ -172,6 +172,7 @@ const translation = { | |||
| }, | |||
| errorMsg: { | |||
| fieldRequired: '{{field}} is required', | |||
| rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.', | |||
| authRequired: 'Authorization is required', | |||
| invalidJson: '{{field}} is invalid JSON', | |||
| fields: { | |||
| @@ -172,6 +172,7 @@ const translation = { | |||
| }, | |||
| errorMsg: { | |||
| fieldRequired: '{{field}} 不能为空', | |||
| rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。', | |||
| authRequired: '请先授权', | |||
| invalidJson: '{{field}} 是非法的 JSON', | |||
| fields: { | |||