@@ -462,18 +462,6 @@ class PublishedRagPipelineApi(Resource):
         if not isinstance(current_user, Account):
             raise Forbidden()
-        parser = reqparse.RequestParser()
-        parser.add_argument("knowledge_base_setting", type=dict, location="json", help="Invalid knowledge base setting.")
-        args = parser.parse_args()
-        if not args.get("knowledge_base_setting"):
-            raise ValueError("Missing knowledge base setting.")
-        knowledge_base_setting_data = args.get("knowledge_base_setting")
-        if not knowledge_base_setting_data:
-            raise ValueError("Missing knowledge base setting.")
-        knowledge_base_setting = KnowledgeBaseUpdateConfiguration(**knowledge_base_setting_data)
         rag_pipeline_service = RagPipelineService()
         with Session(db.engine) as session:
             pipeline = session.merge(pipeline)
@@ -481,7 +469,6 @@ class PublishedRagPipelineApi(Resource):
                 session=session,
                 pipeline=pipeline,
                 account=current_user,
-                knowledge_base_setting=knowledge_base_setting,
             )
             pipeline.is_published = True
             pipeline.workflow_id = workflow.id
@@ -22,11 +22,12 @@ class PluginDatasourceManager(BasePluginClient):
         """

        def transformer(json_response: dict[str, Any]) -> dict:
-            for provider in json_response.get("data", []):
-                declaration = provider.get("declaration", {}) or {}
-                provider_name = declaration.get("identity", {}).get("name")
-                for datasource in declaration.get("datasources", []):
-                    datasource["identity"]["provider"] = provider_name
+            if json_response.get("data"):
+                for provider in json_response.get("data", []):
+                    declaration = provider.get("declaration", {}) or {}
+                    provider_name = declaration.get("identity", {}).get("name")
+                    for datasource in declaration.get("datasources", []):
+                        datasource["identity"]["provider"] = provider_name
             return json_response
@@ -9,6 +9,7 @@ from core.datasource.entities.datasource_entities import (
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.file import File
+from core.file.enums import FileTransferMethod, FileType
 from core.plugin.impl.exc import PluginDaemonClientSideError
 from core.variables.segments import ArrayAnySegment, FileSegment
 from core.variables.variables import ArrayAnyVariable
@@ -118,7 +119,12 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     },
                 )
             case DatasourceProviderType.LOCAL_FILE:
-                upload_file = db.session.query(UploadFile).filter(UploadFile.id == datasource_info["related_id"]).first()
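+                # fail fast on a missing file reference instead of raising a KeyError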
+                related_id = datasource_info.get("related_id")
+                if not related_id:
+                    raise DatasourceNodeError("File does not exist")
+                upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
                 if not upload_file:
                     raise ValueError("Invalid upload file Info")
@@ -128,14 +134,14 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
                     tenant_id=self.tenant_id,
-                    type=datasource_info.get("type", ""),
-                    transfer_method=datasource_info.get("transfer_method", ""),
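+                    # pin local uploads to fixed enum values instead of trusting caller-supplied strings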
+                    type=FileType.CUSTOM,
+                    transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
                     related_id=upload_file.id,
                     size=upload_file.size,
                     storage_key=upload_file.key,
                 )
-                variable_pool.add([self.node_id, "file"], [FileSegment(value=file_info)])
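+                # assumed intent: add the File directly and let the variable pool wrap it in a segment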
+                variable_pool.add([self.node_id, "file"], [file_info])
                 for key, value in datasource_info.items():
                     # construct new key list
                     new_key_list = ["file", key]
@@ -147,7 +153,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
                     inputs=parameters_for_log,
                     metadata={NodeRunMetadataKey.DATASOURCE_INFO: datasource_info},
                     outputs={
-                        "file_info": file_info,
+                        "file_info": datasource_info,
                         "datasource_type": datasource_type,
                     },
                 )
@@ -220,7 +226,7 @@ class DatasourceNode(BaseNode[DatasourceNodeData]):
         variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []

     def _append_variables_recursively(self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """
@@ -53,6 +53,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
 )
 from services.entities.knowledge_entities.rag_pipeline_entities import (
     KnowledgeBaseUpdateConfiguration,
+    KnowledgeConfiguration,
     RagPipelineDatasetCreateEntity,
 )
 from services.errors.account import InvalidActionError, NoPermissionError
@@ -495,11 +496,11 @@ class DatasetService:
     @staticmethod
     def update_rag_pipeline_dataset_settings(session: Session,
                                              dataset: Dataset,
-                                             knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
+                                             knowledge_configuration: KnowledgeConfiguration,
                                              has_published: bool = False):
         if not has_published:
-            dataset.chunk_structure = knowledge_base_setting.chunk_structure
-            index_method = knowledge_base_setting.index_method
+            dataset.chunk_structure = knowledge_configuration.chunk_structure
+            index_method = knowledge_configuration.index_method
             dataset.indexing_technique = index_method.indexing_technique
             if index_method == "high_quality":
                 model_manager = ModelManager()
@@ -519,26 +520,26 @@ class DatasetService:
                     dataset.keyword_number = index_method.economy_setting.keyword_number
                 else:
                     raise ValueError("Invalid index method")
-            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
+            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
             session.add(dataset)
         else:
-            if dataset.chunk_structure and dataset.chunk_structure != knowledge_base_setting.chunk_structure:
+            if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
                 raise ValueError("Chunk structure is not allowed to be updated.")
             action = None
-            if dataset.indexing_technique != knowledge_base_setting.index_method.indexing_technique:
+            if dataset.indexing_technique != knowledge_configuration.index_method.indexing_technique:
                 # if update indexing_technique
-                if knowledge_base_setting.index_method.indexing_technique == "economy":
+                if knowledge_configuration.index_method.indexing_technique == "economy":
                     raise ValueError("Knowledge base indexing technique is not allowed to be updated to economy.")
-                elif knowledge_base_setting.index_method.indexing_technique == "high_quality":
+                elif knowledge_configuration.index_method.indexing_technique == "high_quality":
                     action = "add"
                     # get embedding model setting
                     try:
                         model_manager = ModelManager()
                         embedding_model = model_manager.get_model_instance(
                             tenant_id=current_user.current_tenant_id,
-                            provider=knowledge_base_setting.index_method.embedding_setting.embedding_provider_name,
+                            provider=knowledge_configuration.index_method.embedding_setting.embedding_provider_name,
                             model_type=ModelType.TEXT_EMBEDDING,
-                            model=knowledge_base_setting.index_method.embedding_setting.embedding_model_name,
+                            model=knowledge_configuration.index_method.embedding_setting.embedding_model_name,
                         )
                         dataset.embedding_model = embedding_model.model
                         dataset.embedding_model_provider = embedding_model.provider
@@ -607,9 +608,9 @@ class DatasetService:
                     except ProviderTokenNotInitError as ex:
                         raise ValueError(ex.description)
             elif dataset.indexing_technique == "economy":
-                if dataset.keyword_number != knowledge_base_setting.index_method.economy_setting.keyword_number:
-                    dataset.keyword_number = knowledge_base_setting.index_method.economy_setting.keyword_number
-            dataset.retrieval_model = knowledge_base_setting.retrieval_setting.model_dump()
+                if dataset.keyword_number != knowledge_configuration.index_method.economy_setting.keyword_number:
+                    dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
+            dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
             session.add(dataset)
         session.commit()
         if action:
@@ -47,7 +47,7 @@ from models.workflow import (
     WorkflowType,
 )
 from services.dataset_service import DatasetService
-from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, PipelineTemplateInfoEntity
+from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeBaseUpdateConfiguration, KnowledgeConfiguration, PipelineTemplateInfoEntity
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory
@@ -262,7 +262,6 @@ class RagPipelineService:
         session: Session,
         pipeline: Pipeline,
         account: Account,
-        knowledge_base_setting: KnowledgeBaseUpdateConfiguration,
     ) -> Workflow:
         draft_workflow_stmt = select(Workflow).where(
             Workflow.tenant_id == pipeline.tenant_id,
@@ -291,16 +290,23 @@ class RagPipelineService:
         # commit db session changes
         session.add(workflow)

-        # update dataset
-        dataset = pipeline.dataset
-        if not dataset:
-            raise ValueError("Dataset not found")
-        DatasetService.update_rag_pipeline_dataset_settings(
-            session=session,
-            dataset=dataset,
-            knowledge_base_setting=knowledge_base_setting,
-            has_published=pipeline.is_published
-        )
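+        # derive dataset settings from the workflow's knowledge_index node instead of a request payload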
+        graph = workflow.graph_dict
+        nodes = graph.get("nodes", [])
+        for node in nodes:
+            if node.get("data", {}).get("type") == "knowledge_index":
+                knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
+                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
+                # update dataset
+                dataset = pipeline.dataset
+                if not dataset:
+                    raise ValueError("Dataset not found")
+                DatasetService.update_rag_pipeline_dataset_settings(
+                    session=session,
+                    dataset=dataset,
+                    knowledge_configuration=knowledge_configuration,
+                    has_published=pipeline.is_published
+                )
         # return new workflow
         return workflow