| @@ -0,0 +1,53 @@ | |||
| import type { FC } from 'react' | |||
| import { createContext, useCallback, useEffect, useRef } from 'react' | |||
| import { createDatasetsDetailStore } from './store' | |||
| import type { CommonNodeType, Node } from '../types' | |||
| import { BlockEnum } from '../types' | |||
| import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' | |||
| import { fetchDatasets } from '@/service/datasets' | |||
| type DatasetsDetailStoreApi = ReturnType<typeof createDatasetsDetailStore> | |||
| type DatasetsDetailContextType = DatasetsDetailStoreApi | undefined | |||
| export const DatasetsDetailContext = createContext<DatasetsDetailContextType>(undefined) | |||
| type DatasetsDetailProviderProps = { | |||
| nodes: Node[] | |||
| children: React.ReactNode | |||
| } | |||
| const DatasetsDetailProvider: FC<DatasetsDetailProviderProps> = ({ | |||
| nodes, | |||
| children, | |||
| }) => { | |||
| const storeRef = useRef<DatasetsDetailStoreApi>() | |||
| if (!storeRef.current) | |||
| storeRef.current = createDatasetsDetailStore() | |||
| const updateDatasetsDetail = useCallback(async (datasetIds: string[]) => { | |||
| const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } }) | |||
| if (datasetsDetail && datasetsDetail.length > 0) | |||
| storeRef.current!.getState().updateDatasetsDetail(datasetsDetail) | |||
| }, []) | |||
| useEffect(() => { | |||
| if (!storeRef.current) return | |||
| const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval) | |||
| const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => { | |||
| return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids])) | |||
| }, []) | |||
| if (allDatasetIds.length === 0) return | |||
| updateDatasetsDetail(allDatasetIds) | |||
| // eslint-disable-next-line react-hooks/exhaustive-deps | |||
| }, []) | |||
| return ( | |||
| <DatasetsDetailContext.Provider value={storeRef.current!}> | |||
| {children} | |||
| </DatasetsDetailContext.Provider> | |||
| ) | |||
| } | |||
| export default DatasetsDetailProvider | |||
| @@ -0,0 +1,38 @@ | |||
| import { useContext } from 'react' | |||
| import { createStore, useStore } from 'zustand' | |||
| import type { DataSet } from '@/models/datasets' | |||
| import { DatasetsDetailContext } from './provider' | |||
| import produce from 'immer' | |||
| type DatasetsDetailStore = { | |||
| datasetsDetail: Record<string, DataSet> | |||
| updateDatasetsDetail: (datasetsDetail: DataSet[]) => void | |||
| } | |||
| export const createDatasetsDetailStore = () => { | |||
| return createStore<DatasetsDetailStore>((set, get) => ({ | |||
| datasetsDetail: {}, | |||
| updateDatasetsDetail: (datasets: DataSet[]) => { | |||
| const oldDatasetsDetail = get().datasetsDetail | |||
| const datasetsDetail = datasets.reduce<Record<string, DataSet>>((acc, dataset) => { | |||
| acc[dataset.id] = dataset | |||
| return acc | |||
| }, {}) | |||
| // Merge new datasets detail into old one | |||
| const newDatasetsDetail = produce(oldDatasetsDetail, (draft) => { | |||
| Object.entries(datasetsDetail).forEach(([key, value]) => { | |||
| draft[key] = value | |||
| }) | |||
| }) | |||
| set({ datasetsDetail: newDatasetsDetail }) | |||
| }, | |||
| })) | |||
| } | |||
| export const useDatasetsDetailStore = <T>(selector: (state: DatasetsDetailStore) => T): T => { | |||
| const store = useContext(DatasetsDetailContext) | |||
| if (!store) | |||
| throw new Error('Missing DatasetsDetailContext.Provider in the tree') | |||
| return useStore(store, selector) | |||
| } | |||
| @@ -160,7 +160,7 @@ const Header: FC = () => { | |||
| const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!) | |||
| const onPublish = useCallback(async (params?: PublishWorkflowParams) => { | |||
| if (handleCheckBeforePublish()) { | |||
| if (await handleCheckBeforePublish()) { | |||
| const res = await publishWorkflow({ | |||
| title: params?.title || '', | |||
| releaseNotes: params?.releaseNotes || '', | |||
| @@ -1,10 +1,12 @@ | |||
| import { | |||
| useCallback, | |||
| useMemo, | |||
| useRef, | |||
| } from 'react' | |||
| import { useTranslation } from 'react-i18next' | |||
| import { useStoreApi } from 'reactflow' | |||
| import type { | |||
| CommonNodeType, | |||
| Edge, | |||
| Node, | |||
| } from '../types' | |||
| @@ -27,6 +29,10 @@ import { useGetLanguage } from '@/context/i18n' | |||
| import type { AgentNodeType } from '../nodes/agent/types' | |||
| import { useStrategyProviders } from '@/service/use-strategy' | |||
| import { canFindTool } from '@/utils' | |||
| import { useDatasetsDetailStore } from '../datasets-detail-store/store' | |||
| import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' | |||
| import type { DataSet } from '@/models/datasets' | |||
| import { fetchDatasets } from '@/service/datasets' | |||
| export const useChecklist = (nodes: Node[], edges: Edge[]) => { | |||
| const { t } = useTranslation() | |||
| @@ -37,6 +43,24 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { | |||
| const customTools = useStore(s => s.customTools) | |||
| const workflowTools = useStore(s => s.workflowTools) | |||
| const { data: strategyProviders } = useStrategyProviders() | |||
| const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail) | |||
| const getCheckData = useCallback((data: CommonNodeType<{}>) => { | |||
| let checkData = data | |||
| if (data.type === BlockEnum.KnowledgeRetrieval) { | |||
| const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids | |||
| const _datasets = datasetIds.reduce<DataSet[]>((acc, id) => { | |||
| if (datasetsDetail[id]) | |||
| acc.push(datasetsDetail[id]) | |||
| return acc | |||
| }, []) | |||
| checkData = { | |||
| ...data, | |||
| _datasets, | |||
| } as CommonNodeType<KnowledgeRetrievalNodeType> | |||
| } | |||
| return checkData | |||
| }, [datasetsDetail]) | |||
| const needWarningNodes = useMemo(() => { | |||
| const list = [] | |||
| @@ -75,7 +99,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { | |||
| } | |||
| if (node.type === CUSTOM_NODE) { | |||
| const { errorMessage } = nodesExtraData[node.data.type].checkValid(node.data, t, moreDataForCheckValid) | |||
| const checkData = getCheckData(node.data) | |||
| const { errorMessage } = nodesExtraData[node.data.type].checkValid(checkData, t, moreDataForCheckValid) | |||
| if (errorMessage || !validNodes.find(n => n.id === node.id)) { | |||
| list.push({ | |||
| @@ -109,7 +134,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => { | |||
| } | |||
| return list | |||
| }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders]) | |||
| }, [nodes, edges, isChatMode, buildInTools, customTools, workflowTools, language, nodesExtraData, t, strategyProviders, getCheckData]) | |||
| return needWarningNodes | |||
| } | |||
| @@ -125,8 +150,31 @@ export const useChecklistBeforePublish = () => { | |||
| const store = useStoreApi() | |||
| const nodesExtraData = useNodesExtraData() | |||
| const { data: strategyProviders } = useStrategyProviders() | |||
| const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail) | |||
| const updateTime = useRef(0) | |||
| const getCheckData = useCallback((data: CommonNodeType<{}>, datasets: DataSet[]) => { | |||
| let checkData = data | |||
| if (data.type === BlockEnum.KnowledgeRetrieval) { | |||
| const datasetIds = (data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids | |||
| const datasetsDetail = datasets.reduce<Record<string, DataSet>>((acc, dataset) => { | |||
| acc[dataset.id] = dataset | |||
| return acc | |||
| }, {}) | |||
| const _datasets = datasetIds.reduce<DataSet[]>((acc, id) => { | |||
| if (datasetsDetail[id]) | |||
| acc.push(datasetsDetail[id]) | |||
| return acc | |||
| }, []) | |||
| checkData = { | |||
| ...data, | |||
| _datasets, | |||
| } as CommonNodeType<KnowledgeRetrievalNodeType> | |||
| } | |||
| return checkData | |||
| }, []) | |||
| const handleCheckBeforePublish = useCallback(() => { | |||
| const handleCheckBeforePublish = useCallback(async () => { | |||
| const { | |||
| getNodes, | |||
| edges, | |||
| @@ -141,6 +189,24 @@ export const useChecklistBeforePublish = () => { | |||
| notify({ type: 'error', message: t('workflow.common.maxTreeDepth', { depth: MAX_TREE_DEPTH }) }) | |||
| return false | |||
| } | |||
| // Before publish, we need to fetch datasets detail, in case of the settings of datasets have been changed | |||
| const knowledgeRetrievalNodes = nodes.filter(node => node.data.type === BlockEnum.KnowledgeRetrieval) | |||
| const allDatasetIds = knowledgeRetrievalNodes.reduce<string[]>((acc, node) => { | |||
| return Array.from(new Set([...acc, ...(node.data as CommonNodeType<KnowledgeRetrievalNodeType>).dataset_ids])) | |||
| }, []) | |||
| let datasets: DataSet[] = [] | |||
| if (allDatasetIds.length > 0) { | |||
| updateTime.current = updateTime.current + 1 | |||
| const currUpdateTime = updateTime.current | |||
| const { data: datasetsDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: allDatasetIds } }) | |||
| if (datasetsDetail && datasetsDetail.length > 0) { | |||
| // avoid old data to overwrite the new data | |||
| if (currUpdateTime < updateTime.current) | |||
| return false | |||
| datasets = datasetsDetail | |||
| updateDatasetsDetail(datasetsDetail) | |||
| } | |||
| } | |||
| for (let i = 0; i < nodes.length; i++) { | |||
| const node = nodes[i] | |||
| @@ -161,7 +227,8 @@ export const useChecklistBeforePublish = () => { | |||
| } | |||
| } | |||
| const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(node.data, t, moreDataForCheckValid) | |||
| const checkData = getCheckData(node.data, datasets) | |||
| const { errorMessage } = nodesExtraData[node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid) | |||
| if (errorMessage) { | |||
| notify({ type: 'error', message: `[${node.data.title}] ${errorMessage}` }) | |||
| @@ -185,7 +252,7 @@ export const useChecklistBeforePublish = () => { | |||
| } | |||
| return true | |||
| }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders]) | |||
| }, [store, isChatMode, notify, t, buildInTools, customTools, workflowTools, language, nodesExtraData, strategyProviders, updateDatasetsDetail, getCheckData]) | |||
| return { | |||
| handleCheckBeforePublish, | |||
| @@ -99,6 +99,7 @@ import { useEventEmitterContextContext } from '@/context/event-emitter' | |||
| import Confirm from '@/app/components/base/confirm' | |||
| import { FILE_EXTS } from '@/app/components/base/prompt-editor/constants' | |||
| import { fetchFileUploadConfig } from '@/service/common' | |||
| import DatasetsDetailProvider from './datasets-detail-store/provider' | |||
| const nodeTypes = { | |||
| [CUSTOM_NODE]: CustomNode, | |||
| @@ -448,11 +449,13 @@ const WorkflowWrap = memo(() => { | |||
| nodes={nodesData} | |||
| edges={edgesData} > | |||
| <FeaturesProvider features={initialFeatures}> | |||
| <Workflow | |||
| nodes={nodesData} | |||
| edges={edgesData} | |||
| viewport={data?.graph.viewport} | |||
| /> | |||
| <DatasetsDetailProvider nodes={nodesData}> | |||
| <Workflow | |||
| nodes={nodesData} | |||
| edges={edgesData} | |||
| viewport={data?.graph.viewport} | |||
| /> | |||
| </DatasetsDetailProvider> | |||
| </FeaturesProvider> | |||
| </WorkflowHistoryProvider> | |||
| </ReactFlowProvider> | |||
| @@ -1,33 +1,30 @@ | |||
| import { type FC, useEffect, useRef, useState } from 'react' | |||
| import { type FC, useEffect, useState } from 'react' | |||
| import React from 'react' | |||
| import type { KnowledgeRetrievalNodeType } from './types' | |||
| import { Folder } from '@/app/components/base/icons/src/vender/solid/files' | |||
| import type { NodeProps } from '@/app/components/workflow/types' | |||
| import { fetchDatasets } from '@/service/datasets' | |||
| import type { DataSet } from '@/models/datasets' | |||
| import { useDatasetsDetailStore } from '../../datasets-detail-store/store' | |||
| const Node: FC<NodeProps<KnowledgeRetrievalNodeType>> = ({ | |||
| data, | |||
| }) => { | |||
| const [selectedDatasets, setSelectedDatasets] = useState<DataSet[]>([]) | |||
| const updateTime = useRef(0) | |||
| useEffect(() => { | |||
| (async () => { | |||
| updateTime.current = updateTime.current + 1 | |||
| const currUpdateTime = updateTime.current | |||
| const datasetsDetail = useDatasetsDetailStore(s => s.datasetsDetail) | |||
| if (data.dataset_ids?.length > 0) { | |||
| const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: data.dataset_ids } }) | |||
| // avoid old data overwrite new data | |||
| if (currUpdateTime < updateTime.current) | |||
| return | |||
| setSelectedDatasets(dataSetsWithDetail) | |||
| } | |||
| else { | |||
| setSelectedDatasets([]) | |||
| } | |||
| })() | |||
| }, [data.dataset_ids]) | |||
| useEffect(() => { | |||
| if (data.dataset_ids?.length > 0) { | |||
| const dataSetsWithDetail = data.dataset_ids.reduce<DataSet[]>((acc, id) => { | |||
| if (datasetsDetail[id]) | |||
| acc.push(datasetsDetail[id]) | |||
| return acc | |||
| }, []) | |||
| setSelectedDatasets(dataSetsWithDetail) | |||
| } | |||
| else { | |||
| setSelectedDatasets([]) | |||
| } | |||
| }, [data.dataset_ids, datasetsDetail]) | |||
| if (!selectedDatasets.length) | |||
| return null | |||
| @@ -41,6 +41,7 @@ import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-s | |||
| import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' | |||
| import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' | |||
| import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use-available-var-list' | |||
| import { useDatasetsDetailStore } from '../../datasets-detail-store/store' | |||
| const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| const { nodesReadOnly: readOnly } = useNodesReadOnly() | |||
| @@ -49,6 +50,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) | |||
| const startNodeId = startNode?.id | |||
| const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload) | |||
| const updateDatasetsDetail = useDatasetsDetailStore(s => s.updateDatasetsDetail) | |||
| const inputRef = useRef(inputs) | |||
| @@ -218,15 +220,12 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| (async () => { | |||
| const inputs = inputRef.current | |||
| const datasetIds = inputs.dataset_ids | |||
| let _datasets = selectedDatasets | |||
| if (datasetIds?.length > 0) { | |||
| const { data: dataSetsWithDetail } = await fetchDatasets({ url: '/datasets', params: { page: 1, ids: datasetIds } as any }) | |||
| _datasets = dataSetsWithDetail | |||
| setSelectedDatasets(dataSetsWithDetail) | |||
| } | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.dataset_ids = datasetIds | |||
| draft._datasets = _datasets | |||
| }) | |||
| setInputs(newInputs) | |||
| setSelectedDatasetsLoaded(true) | |||
| @@ -256,7 +255,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| } = getSelectedDatasetsMode(newDatasets) | |||
| const newInputs = produce(inputs, (draft) => { | |||
| draft.dataset_ids = newDatasets.map(d => d.id) | |||
| draft._datasets = newDatasets | |||
| if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { | |||
| const multipleRetrievalConfig = draft.multiple_retrieval_config | |||
| @@ -266,6 +264,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| }) | |||
| } | |||
| }) | |||
| updateDatasetsDetail(newDatasets) | |||
| setInputs(newInputs) | |||
| setSelectedDatasets(newDatasets) | |||
| @@ -275,7 +274,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { | |||
| || allExternal | |||
| ) | |||
| setRerankModelOpen(true) | |||
| }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider]) | |||
| }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider, updateDatasetsDetail]) | |||
| const filterVar = useCallback((varPayload: Var) => { | |||
| return varPayload.type === VarType.string | |||