| @@ -52,7 +52,7 @@ const InstallFromMarketplace = ({ | |||
| <div className='flex items-center justify-between'> | |||
| <div className='system-md-semibold flex cursor-pointer items-center gap-1 text-text-primary' onClick={() => setCollapse(!collapse)}> | |||
| <RiArrowDownSLine className={cn('h-4 w-4', collapse && '-rotate-90')} /> | |||
| {t('common.modelProvider.installProvider')} | |||
| {t('common.modelProvider.installDataSourceProvider')} | |||
| </div> | |||
| <div className='mb-2 flex items-center pt-2'> | |||
| <span className='system-sm-regular pr-1 text-text-tertiary'>{t('common.modelProvider.discoverMore')}</span> | |||
| @@ -86,7 +86,10 @@ const OptionCard = memo(({ | |||
| readonly && 'cursor-not-allowed', | |||
| wrapperClassName && (typeof wrapperClassName === 'function' ? wrapperClassName(isActive) : wrapperClassName), | |||
| )} | |||
| onClick={() => !readonly && enableSelect && id && onClick?.(id)} | |||
| onClick={(e) => { | |||
| e.stopPropagation() | |||
| !readonly && enableSelect && id && onClick?.(id) | |||
| }} | |||
| > | |||
| <div className={cn( | |||
| 'relative flex rounded-t-xl p-2', | |||
| @@ -2,6 +2,7 @@ import type { NodeDefault } from '../../types' | |||
| import type { KnowledgeBaseNodeType } from './types' | |||
| import { genNodeMetaData } from '@/app/components/workflow/utils' | |||
| import { BlockEnum } from '@/app/components/workflow/types' | |||
| import { IndexingType } from '@/app/components/datasets/create/step-two' | |||
| const metaData = genNodeMetaData({ | |||
| sort: 3.1, | |||
| @@ -27,8 +28,17 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = { | |||
| chunk_structure, | |||
| indexing_technique, | |||
| retrieval_model, | |||
| embedding_model, | |||
| embedding_model_provider, | |||
| index_chunk_variable_selector, | |||
| } = payload | |||
| const { | |||
| search_method, | |||
| reranking_enable, | |||
| reranking_model, | |||
| } = retrieval_model || {} | |||
| if (!chunk_structure) { | |||
| return { | |||
| isValid: false, | |||
| @@ -36,6 +46,13 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = { | |||
| } | |||
| } | |||
| if (index_chunk_variable_selector.length === 0) { | |||
| return { | |||
| isValid: false, | |||
| errorMessage: t('workflow.nodes.knowledgeBase.chunksVariableIsRequired'), | |||
| } | |||
| } | |||
| if (!indexing_technique) { | |||
| return { | |||
| isValid: false, | |||
| @@ -43,13 +60,27 @@ const nodeDefault: NodeDefault<KnowledgeBaseNodeType> = { | |||
| } | |||
| } | |||
| if (!retrieval_model || !retrieval_model.search_method) { | |||
| if (indexing_technique === IndexingType.QUALIFIED && (!embedding_model || !embedding_model_provider)) { | |||
| return { | |||
| isValid: false, | |||
| errorMessage: t('workflow.nodes.knowledgeBase.embeddingModelIsRequired'), | |||
| } | |||
| } | |||
| if (!retrieval_model || !search_method) { | |||
| return { | |||
| isValid: false, | |||
| errorMessage: t('workflow.nodes.knowledgeBase.retrievalSettingIsRequired'), | |||
| } | |||
| } | |||
| if (reranking_enable && (!reranking_model || !reranking_model.reranking_provider_name || !reranking_model.reranking_model_name)) { | |||
| return { | |||
| isValid: false, | |||
| errorMessage: t('workflow.nodes.knowledgeBase.rerankingModelIsRequired'), | |||
| } | |||
| } | |||
| return { | |||
| isValid: true, | |||
| errorMessage: '', | |||
| @@ -9,13 +9,17 @@ import { | |||
| ChunkStructureEnum, | |||
| IndexMethodEnum, | |||
| RetrievalSearchMethodEnum, | |||
| WeightedScoreEnum, | |||
| } from '../types' | |||
| import type { | |||
| HybridSearchModeEnum, | |||
| KnowledgeBaseNodeType, | |||
| RerankingModel, | |||
| } from '../types' | |||
| import { | |||
| HybridSearchModeEnum, | |||
| } from '../types' | |||
| import { isHighQualitySearchMethod } from '../utils' | |||
| import { DEFAULT_WEIGHTED_SCORE, RerankingModeEnum } from '@/models/datasets' | |||
| export const useConfig = (id: string) => { | |||
| const store = useStoreApi() | |||
| @@ -35,6 +39,25 @@ export const useConfig = (id: string) => { | |||
| }) | |||
| }, [id, handleNodeDataUpdateWithSyncDraft]) | |||
| const getDefaultWeights = useCallback(({ | |||
| embeddingModel, | |||
| embeddingModelProvider, | |||
| }: { | |||
| embeddingModel: string | |||
| embeddingModelProvider: string | |||
| }) => { | |||
| return { | |||
| vector_setting: { | |||
| vector_weight: DEFAULT_WEIGHTED_SCORE.other.semantic, | |||
| embedding_provider_name: embeddingModelProvider || '', | |||
| embedding_model_name: embeddingModel, | |||
| }, | |||
| keyword_setting: { | |||
| keyword_weight: DEFAULT_WEIGHTED_SCORE.other.keyword, | |||
| }, | |||
| } | |||
| }, []) | |||
| const handleChunkStructureChange = useCallback((chunkStructure: ChunkStructureEnum) => { | |||
| const nodeData = getNodeData() | |||
| const { | |||
| @@ -80,39 +103,72 @@ export const useConfig = (id: string) => { | |||
| embeddingModelProvider: string | |||
| }) => { | |||
| const nodeData = getNodeData() | |||
| handleNodeDataUpdate({ | |||
| const defaultWeights = getDefaultWeights({ | |||
| embeddingModel, | |||
| embeddingModelProvider, | |||
| }) | |||
| const changeData = { | |||
| embedding_model: embeddingModel, | |||
| embedding_model_provider: embeddingModelProvider, | |||
| retrieval_model: { | |||
| ...nodeData?.data.retrieval_model, | |||
| vector_setting: { | |||
| ...nodeData?.data.retrieval_model.vector_setting, | |||
| embedding_provider_name: embeddingModelProvider, | |||
| embedding_model_name: embeddingModel, | |||
| }, | |||
| }, | |||
| }) | |||
| }, [getNodeData, handleNodeDataUpdate]) | |||
| } | |||
| if (changeData.retrieval_model.weights) { | |||
| changeData.retrieval_model = { | |||
| ...changeData.retrieval_model, | |||
| weights: { | |||
| ...changeData.retrieval_model.weights, | |||
| vector_setting: { | |||
| ...changeData.retrieval_model.weights.vector_setting, | |||
| embedding_provider_name: embeddingModelProvider, | |||
| embedding_model_name: embeddingModel, | |||
| }, | |||
| }, | |||
| } | |||
| } | |||
| else { | |||
| changeData.retrieval_model = { | |||
| ...changeData.retrieval_model, | |||
| weights: defaultWeights, | |||
| } | |||
| } | |||
| handleNodeDataUpdate(changeData) | |||
| }, [getNodeData, getDefaultWeights, handleNodeDataUpdate]) | |||
| const handleRetrievalSearchMethodChange = useCallback((searchMethod: RetrievalSearchMethodEnum) => { | |||
| const nodeData = getNodeData() | |||
| handleNodeDataUpdate({ | |||
| const changeData = { | |||
| retrieval_model: { | |||
| ...nodeData?.data.retrieval_model, | |||
| search_method: searchMethod, | |||
| reranking_mode: nodeData?.data.retrieval_model.reranking_mode || RerankingModeEnum.RerankingModel, | |||
| }, | |||
| }) | |||
| } | |||
| if (searchMethod === RetrievalSearchMethodEnum.hybrid) { | |||
| changeData.retrieval_model = { | |||
| ...changeData.retrieval_model, | |||
| reranking_enable: changeData.retrieval_model.reranking_mode === RerankingModeEnum.RerankingModel, | |||
| } | |||
| } | |||
| handleNodeDataUpdate(changeData) | |||
| }, [getNodeData, handleNodeDataUpdate]) | |||
| const handleHybridSearchModeChange = useCallback((hybridSearchMode: HybridSearchModeEnum) => { | |||
| const nodeData = getNodeData() | |||
| const defaultWeights = getDefaultWeights({ | |||
| embeddingModel: nodeData?.data.embedding_model || '', | |||
| embeddingModelProvider: nodeData?.data.embedding_model_provider || '', | |||
| }) | |||
| handleNodeDataUpdate({ | |||
| retrieval_model: { | |||
| ...nodeData?.data.retrieval_model, | |||
| reranking_mode: hybridSearchMode, | |||
| reranking_enable: hybridSearchMode === HybridSearchModeEnum.RerankingModel, | |||
| weights: nodeData?.data.retrieval_model.weights || defaultWeights, | |||
| }, | |||
| }) | |||
| }, [getNodeData, handleNodeDataUpdate]) | |||
| }, [getNodeData, getDefaultWeights, handleNodeDataUpdate]) | |||
| const handleRerankingModelEnabledChange = useCallback((rerankingModelEnabled: boolean) => { | |||
| const nodeData = getNodeData() | |||
| @@ -130,11 +186,10 @@ export const useConfig = (id: string) => { | |||
| retrieval_model: { | |||
| ...nodeData?.data.retrieval_model, | |||
| weights: { | |||
| weight_type: 'weighted_score', | |||
| weight_type: WeightedScoreEnum.Customized, | |||
| vector_setting: { | |||
| ...nodeData?.data.retrieval_model.weights?.vector_setting, | |||
| vector_weight: weightedScore.value[0], | |||
| embedding_provider_name: '', | |||
| embedding_model_name: '', | |||
| }, | |||
| keyword_setting: { | |||
| keyword_weight: weightedScore.value[1], | |||
| @@ -493,6 +493,7 @@ const translation = { | |||
| toBeConfigured: 'To be configured', | |||
| configureTip: 'Set up api-key or add model to use', | |||
| installProvider: 'Install model providers', | |||
| installDataSourceProvider: 'Install data source providers', | |||
| discoverMore: 'Discover more in ', | |||
| emptyProviderTitle: 'Model provider not set up', | |||
| emptyProviderTip: 'Please install a model provider first.', | |||
| @@ -955,7 +955,10 @@ const translation = { | |||
| aboutRetrieval: 'about retrieval method.', | |||
| chunkIsRequired: 'Chunk structure is required', | |||
| indexMethodIsRequired: 'Index method is required', | |||
| chunksVariableIsRequired: 'Chunks variable is required', | |||
| embeddingModelIsRequired: 'Embedding model is required', | |||
| retrievalSettingIsRequired: 'Retrieval setting is required', | |||
| rerankingModelIsRequired: 'Reranking model is required', | |||
| }, | |||
| }, | |||
| tracing: { | |||
| @@ -484,6 +484,7 @@ const translation = { | |||
| emptyProviderTitle: 'モデルプロバイダーが設定されていません', | |||
| discoverMore: 'もっと発見する', | |||
| installProvider: 'モデルプロバイダーをインストールする', | |||
| installDataSourceProvider: 'データソースプロバイダーをインストールする', | |||
| configureTip: 'API キーを設定するか、使用するモデルを追加してください', | |||
| toBeConfigured: '設定中', | |||
| emptyProviderTip: '最初にモデルプロバイダーをインストールしてください。', | |||
| @@ -487,6 +487,7 @@ const translation = { | |||
| toBeConfigured: '待配置', | |||
| configureTip: '请配置 API 密钥,添加模型。', | |||
| installProvider: '安装模型供应商', | |||
| installDataSourceProvider: '安装数据源供应商', | |||
| discoverMore: '发现更多就在', | |||
| emptyProviderTitle: '尚未安装模型供应商', | |||
| emptyProviderTip: '请安装模型供应商。', | |||
| @@ -955,7 +955,10 @@ const translation = { | |||
| aboutRetrieval: '关于知识检索。', | |||
| chunkIsRequired: '分段结构是必需的', | |||
| indexMethodIsRequired: '索引方法是必需的', | |||
| chunksVariableIsRequired: 'Chunks 变量是必需的', | |||
| embeddingModelIsRequired: 'Embedding 模型是必需的', | |||
| retrievalSettingIsRequired: '检索设置是必需的', | |||
| rerankingModelIsRequired: 'Reranking 模型是必需的', | |||
| }, | |||
| }, | |||
| tracing: { | |||