| args["doc_form"], | args["doc_form"], | ||||
| args["doc_language"], | args["doc_language"], | ||||
| ) | ) | ||||
| return response, 200 | |||||
| return response.model_dump(), 200 | |||||
| class DataSourceNotionDatasetSyncApi(Resource): | class DataSourceNotionDatasetSyncApi(Resource): |
| except Exception as e: | except Exception as e: | ||||
| raise IndexingEstimateError(str(e)) | raise IndexingEstimateError(str(e)) | ||||
| return response, 200 | |||||
| return response.model_dump(), 200 | |||||
| class DatasetRelatedAppListApi(Resource): | class DatasetRelatedAppListApi(Resource): | ||||
| }, 200 | }, 200 | ||||
| class DatasetAutoDisableLogApi(Resource): | |||||
| @setup_required | |||||
| @login_required | |||||
| @account_initialization_required | |||||
| def get(self, dataset_id): | |||||
| dataset_id_str = str(dataset_id) | |||||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||||
| if dataset is None: | |||||
| raise NotFound("Dataset not found.") | |||||
| return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 | |||||
| api.add_resource(DatasetListApi, "/datasets") | api.add_resource(DatasetListApi, "/datasets") | ||||
| api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | ||||
| api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") | api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") | ||||
| api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") | api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") | ||||
| api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") | api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") | ||||
| api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") | api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") | ||||
| api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs") |
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
+from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task

        parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
        parser.add_argument("original_document_id", type=str, required=False, location="json")
        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
-       parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
        parser.add_argument(
            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
        )
+       parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
        args = parser.parse_args()
+       knowledge_config = KnowledgeConfig(**args)
-       if not dataset.indexing_technique and not args["indexing_technique"]:
+       if not dataset.indexing_technique and not knowledge_config.indexing_technique:
            raise ValueError("indexing_technique is required.")

        # validate args
-       DocumentService.document_create_args_validate(args)
+       DocumentService.document_create_args_validate(knowledge_config)

        try:
-           documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
+           documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:

        return {"documents": documents, "batch": batch}

+   @setup_required
+   @login_required
+   @account_initialization_required
+   def delete(self, dataset_id):
+       dataset_id = str(dataset_id)
+       dataset = DatasetService.get_dataset(dataset_id)
+       if dataset is None:
+           raise NotFound("Dataset not found.")
+       # check user's model setting
+       DatasetService.check_dataset_model_setting(dataset)
+       try:
+           document_ids = request.args.getlist("document_id")
+           DocumentService.delete_documents(dataset, document_ids)
+       except services.errors.document.DocumentIndexingError:
+           raise DocumentIndexingError("Cannot delete document during indexing.")
+       return {"result": "success"}, 204
class DatasetInitApi(Resource):
    @setup_required

        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()

-       if args["indexing_technique"] == "high_quality":
-           if args["embedding_model"] is None or args["embedding_model_provider"] is None:
+       knowledge_config = KnowledgeConfig(**args)
+       if knowledge_config.indexing_technique == "high_quality":
+           if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
                raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
            try:
                model_manager = ModelManager()

            raise ProviderNotInitializeError(ex.description)

        # validate args
-       DocumentService.document_create_args_validate(args)
+       DocumentService.document_create_args_validate(knowledge_config)

        try:
            dataset, documents, batch = DocumentService.save_document_without_dataset_id(
-               tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
+               tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        except Exception as e:
            raise IndexingEstimateError(str(e))
-       return response
+       return response.model_dump(), 200

class DocumentBatchIndexingEstimateApi(DocumentResource):

        documents = self.get_batch_documents(dataset_id, batch)
        response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
        if not documents:
-           return response
+           return response, 200
        data_process_rule = documents[0].dataset_process_rule
        data_process_rule_dict = data_process_rule.to_dict()
        info_list = []

            raise ProviderNotInitializeError(ex.description)
        except Exception as e:
            raise IndexingEstimateError(str(e))
-       return response
+       return response.model_dump(), 200

class DocumentBatchIndexingStatusApi(DocumentResource):

        if metadata == "only":
            response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
        elif metadata == "without":
-           process_rules = DatasetService.get_process_rules(dataset_id)
+           dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+           document_process_rules = document.dataset_process_rule.to_dict()
            data_source_info = document.data_source_detail_dict
            response = {
                "id": document.id,
                "data_source_type": document.data_source_type,
                "data_source_info": data_source_info,
                "dataset_process_rule_id": document.dataset_process_rule_id,
-               "dataset_process_rule": process_rules,
+               "dataset_process_rule": dataset_process_rules,
+               "document_process_rule": document_process_rules,
                "name": document.name,
                "created_from": document.created_from,
                "created_by": document.created_by,
                "doc_language": document.doc_language,
            }
        else:
-           process_rules = DatasetService.get_process_rules(dataset_id)
+           dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+           document_process_rules = document.dataset_process_rule.to_dict()
            data_source_info = document.data_source_detail_dict
            response = {
                "id": document.id,
                "data_source_type": document.data_source_type,
                "data_source_info": data_source_info,
                "dataset_process_rule_id": document.dataset_process_rule_id,
-               "dataset_process_rule": process_rules,
+               "dataset_process_rule": dataset_process_rules,
+               "document_process_rule": document_process_rules,
                "name": document.name,
                "created_from": document.created_from,
                "created_by": document.created_by,
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
-   def patch(self, dataset_id, document_id, action):
+   def patch(self, dataset_id, action):
        dataset_id = str(dataset_id)
-       document_id = str(document_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if dataset is None:
            raise NotFound("Dataset not found.")
        # check user's permission
        DatasetService.check_dataset_permission(dataset, current_user)
-       document = self.get_document(dataset_id, document_id)
-       indexing_cache_key = "document_{}_indexing".format(document.id)
-       cache_result = redis_client.get(indexing_cache_key)
-       if cache_result is not None:
-           raise InvalidActionError("Document is being indexed, please try again later")
-       if action == "enable":
-           if document.enabled:
-               raise InvalidActionError("Document already enabled.")
-           document.enabled = True
-           document.disabled_at = None
-           document.disabled_by = None
-           document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-           db.session.commit()
-           # Set cache to prevent indexing the same document multiple times
-           redis_client.setex(indexing_cache_key, 600, 1)
-           add_document_to_index_task.delay(document_id)
-           return {"result": "success"}, 200
-       elif action == "disable":
-           if not document.completed_at or document.indexing_status != "completed":
-               raise InvalidActionError("Document is not completed.")
-           if not document.enabled:
-               raise InvalidActionError("Document already disabled.")
-           document.enabled = False
-           document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
-           document.disabled_by = current_user.id
-           document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-           db.session.commit()
-           # Set cache to prevent indexing the same document multiple times
-           redis_client.setex(indexing_cache_key, 600, 1)
-           remove_document_from_index_task.delay(document_id)
-           return {"result": "success"}, 200
-       elif action == "archive":
-           if document.archived:
-               raise InvalidActionError("Document already archived.")
-           document.archived = True
-           document.archived_at = datetime.now(UTC).replace(tzinfo=None)
-           document.archived_by = current_user.id
-           document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-           db.session.commit()
-           if document.enabled:
-               # Set cache to prevent indexing the same document multiple times
-               redis_client.setex(indexing_cache_key, 600, 1)
-               remove_document_from_index_task.delay(document_id)
-           return {"result": "success"}, 200
-       elif action == "un_archive":
-           if not document.archived:
-               raise InvalidActionError("Document is not archived.")
-           document.archived = False
-           document.archived_at = None
-           document.archived_by = None
-           document.updated_at = datetime.now(UTC).replace(tzinfo=None)
-           db.session.commit()
-           # Set cache to prevent indexing the same document multiple times
-           redis_client.setex(indexing_cache_key, 600, 1)
-           add_document_to_index_task.delay(document_id)
-           return {"result": "success"}, 200
-       else:
-           raise InvalidActionError()
+       document_ids = request.args.getlist("document_id")
+       for document_id in document_ids:
+           document = self.get_document(dataset_id, document_id)
+           indexing_cache_key = "document_{}_indexing".format(document.id)
+           cache_result = redis_client.get(indexing_cache_key)
+           if cache_result is not None:
+               raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
+           if action == "enable":
+               if document.enabled:
+                   continue
+               document.enabled = True
+               document.disabled_at = None
+               document.disabled_by = None
+               document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+               db.session.commit()
+               # Set cache to prevent indexing the same document multiple times
+               redis_client.setex(indexing_cache_key, 600, 1)
+               add_document_to_index_task.delay(document_id)
+           elif action == "disable":
+               if not document.completed_at or document.indexing_status != "completed":
+                   raise InvalidActionError(f"Document: {document.name} is not completed.")
+               if not document.enabled:
+                   continue
+               document.enabled = False
+               document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
+               document.disabled_by = current_user.id
+               document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+               db.session.commit()
+               # Set cache to prevent indexing the same document multiple times
+               redis_client.setex(indexing_cache_key, 600, 1)
+               remove_document_from_index_task.delay(document_id)
+           elif action == "archive":
+               if document.archived:
+                   continue
+               document.archived = True
+               document.archived_at = datetime.now(UTC).replace(tzinfo=None)
+               document.archived_by = current_user.id
+               document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+               db.session.commit()
+               if document.enabled:
+                   # Set cache to prevent indexing the same document multiple times
+                   redis_client.setex(indexing_cache_key, 600, 1)
+                   remove_document_from_index_task.delay(document_id)
+           elif action == "un_archive":
+               if not document.archived:
+                   continue
+               document.archived = False
+               document.archived_at = None
+               document.archived_by = None
+               document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+               db.session.commit()
+               # Set cache to prevent indexing the same document multiple times
+               redis_client.setex(indexing_cache_key, 600, 1)
+               add_document_to_index_task.delay(document_id)
+           else:
+               raise InvalidActionError()
+       return {"result": "success"}, 200
class DocumentPauseApi(DocumentResource):

)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
-api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>")
+api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
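The re-registered DocumentStatusApi route now carries the action in the path and the target documents as repeated document_id query parameters. A minimal sketch of a client call against it (base URL, ids, and auth header are placeholder assumptions):

    import requests

    BASE = "http://localhost:5001/console/api"  # placeholder base URL

    # Enable several documents in one request; the handler loops over getlist("document_id").
    resp = requests.patch(
        f"{BASE}/datasets/<dataset-uuid>/documents/status/enable/batch",
        params={"document_id": ["<doc-uuid-1>", "<doc-uuid-2>"]},
        headers={"Authorization": "Bearer <console-token>"},  # placeholder auth
    )
    print(resp.json())  # {"result": "success"} when no document is mid-indexing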
import uuid
-from datetime import UTC, datetime

import pandas as pd
from flask import request

import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
-from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
+from controllers.console.datasets.error import (
+    ChildChunkDeleteIndexError,
+    ChildChunkIndexingError,
+    InvalidActionError,
+    NoFileUploadedError,
+    TooManyFilesError,
+)
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_knowledge_limit_check,

from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
-from extensions.ext_database import db
from extensions.ext_redis import redis_client
-from fields.segment_fields import segment_fields
+from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
-from models import DocumentSegment
+from models.dataset import ChildChunk, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
+from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
+from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
+from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
-from tasks.disable_segment_from_index_task import disable_segment_from_index_task
-from tasks.enable_segment_to_index_task import enable_segment_to_index_task
class DatasetDocumentSegmentListApi(Resource):

            raise NotFound("Document not found.")

        parser = reqparse.RequestParser()
-       parser.add_argument("last_id", type=str, default=None, location="args")
        parser.add_argument("limit", type=int, default=20, location="args")
        parser.add_argument("status", type=str, action="append", default=[], location="args")
        parser.add_argument("hit_count_gte", type=int, default=None, location="args")
        parser.add_argument("enabled", type=str, default="all", location="args")
        parser.add_argument("keyword", type=str, default=None, location="args")
+       parser.add_argument("page", type=int, default=1, location="args")
        args = parser.parse_args()

-       last_id = args["last_id"]
+       page = args["page"]
        limit = min(args["limit"], 100)
        status_list = args["status"]
        hit_count_gte = args["hit_count_gte"]

        query = DocumentSegment.query.filter(
            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-       )
-       if last_id is not None:
-           last_segment = db.session.get(DocumentSegment, str(last_id))
-           if last_segment:
-               query = query.filter(DocumentSegment.position > last_segment.position)
-           else:
-               return {"data": [], "has_more": False, "limit": limit}, 200
+       ).order_by(DocumentSegment.position.asc())

        if status_list:
            query = query.filter(DocumentSegment.status.in_(status_list))

        elif args["enabled"].lower() == "false":
            query = query.filter(DocumentSegment.enabled == False)

-       total = query.count()
-       segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
-       has_more = False
-       if len(segments) > limit:
-           has_more = True
-           segments = segments[:-1]
-       return {
-           "data": marshal(segments, segment_fields),
-           "doc_form": document.doc_form,
-           "has_more": has_more,
-           "limit": limit,
-           "total": total,
-       }, 200
+       segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+       response = {
+           "data": marshal(segments.items, segment_fields),
+           "limit": limit,
+           "total": segments.total,
+           "total_pages": segments.pages,
+           "page": page,
+       }
+       return response, 200
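The segment list now reports page-based metadata (total, total_pages, page) instead of the old last_id cursor. A rough consumer sketch that walks all pages (URL and auth header are placeholders):

    import requests

    url = "http://localhost:5001/console/api/datasets/<dataset-uuid>/documents/<doc-uuid>/segments"  # placeholder
    headers = {"Authorization": "Bearer <console-token>"}  # placeholder auth

    page = 1
    while True:
        data = requests.get(url, params={"page": page, "limit": 50}, headers=headers).json()
        for segment in data["data"]:
            ...  # each item is a segment marshalled with segment_fields
        if page >= data["total_pages"]:
            break
        page += 1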
+   @setup_required
+   @login_required
+   @account_initialization_required
+   def delete(self, dataset_id, document_id):
+       # check dataset
+       dataset_id = str(dataset_id)
+       dataset = DatasetService.get_dataset(dataset_id)
+       if not dataset:
+           raise NotFound("Dataset not found.")
+       # check user's model setting
+       DatasetService.check_dataset_model_setting(dataset)
+       # check document
+       document_id = str(document_id)
+       document = DocumentService.get_document(dataset_id, document_id)
+       if not document:
+           raise NotFound("Document not found.")
+       segment_ids = request.args.getlist("segment_id")
+       # The role of the current user in the ta table must be admin or owner
+       if not current_user.is_editor:
+           raise Forbidden()
+       try:
+           DatasetService.check_dataset_permission(dataset, current_user)
+       except services.errors.account.NoPermissionError as e:
+           raise Forbidden(str(e))
+       SegmentService.delete_segments(segment_ids, document, dataset)
+       return {"result": "success"}, 200
class DatasetDocumentSegmentApi(Resource):

    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check("vector_space")
-   def patch(self, dataset_id, segment_id, action):
+   def patch(self, dataset_id, document_id, action):
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound("Dataset not found.")
+       document_id = str(document_id)
+       document = DocumentService.get_document(dataset_id, document_id)
+       if not document:
+           raise NotFound("Document not found.")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # The role of the current user in the ta table must be admin, owner, or editor

            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
+       segment_ids = request.args.getlist("segment_id")
-       segment = DocumentSegment.query.filter(
-           DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
-       ).first()
-       if not segment:
-           raise NotFound("Segment not found.")
-       if segment.status != "completed":
-           raise NotFound("Segment is not completed, enable or disable function is not allowed")
-       document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
+       document_indexing_cache_key = "document_{}_indexing".format(document.id)
        cache_result = redis_client.get(document_indexing_cache_key)
        if cache_result is not None:
            raise InvalidActionError("Document is being indexed, please try again later")
-       indexing_cache_key = "segment_{}_indexing".format(segment.id)
-       cache_result = redis_client.get(indexing_cache_key)
-       if cache_result is not None:
-           raise InvalidActionError("Segment is being indexed, please try again later")
-       if action == "enable":
-           if segment.enabled:
-               raise InvalidActionError("Segment is already enabled.")
-           segment.enabled = True
-           segment.disabled_at = None
-           segment.disabled_by = None
-           db.session.commit()
-           # Set cache to prevent indexing the same segment multiple times
-           redis_client.setex(indexing_cache_key, 600, 1)
-           enable_segment_to_index_task.delay(segment.id)
-           return {"result": "success"}, 200
-       elif action == "disable":
-           if not segment.enabled:
-               raise InvalidActionError("Segment is already disabled.")
-           segment.enabled = False
-           segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
-           segment.disabled_by = current_user.id
-           db.session.commit()
-           # Set cache to prevent indexing the same segment multiple times
-           redis_client.setex(indexing_cache_key, 600, 1)
-           disable_segment_from_index_task.delay(segment.id)
-           return {"result": "success"}, 200
-       else:
-           raise InvalidActionError()
+       try:
+           SegmentService.update_segments_status(segment_ids, action, dataset, document)
+       except Exception as e:
+           raise InvalidActionError(str(e))
+       return {"result": "success"}, 200
class DatasetDocumentSegmentAddApi(Resource):

        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
+       parser.add_argument(
+           "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
+       )
        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args, document)
-       segment = SegmentService.update_segment(args, segment, document, dataset)
+       segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

    @setup_required

        return {"job_id": job_id, "job_status": cache_result.decode()}, 200
+class ChildChunkAddApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
+    def post(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        if not current_user.is_editor:
+            raise Forbidden()
+        # check embedding model setting
+        if dataset.indexing_technique == "high_quality":
+            try:
+                model_manager = ModelManager()
+                model_manager.get_model_instance(
+                    tenant_id=current_user.current_tenant_id,
+                    provider=dataset.embedding_model_provider,
+                    model_type=ModelType.TEXT_EMBEDDING,
+                    model=dataset.embedding_model,
+                )
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    "No Embedding Model available. Please configure a valid provider "
+                    "in the Settings -> Model Provider."
+                )
+            except ProviderTokenNotInitError as ex:
+                raise ProviderNotInitializeError(ex.description)
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        try:
+            child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
+        except ChildChunkIndexingServiceError as e:
+            raise ChildChunkIndexingError(str(e))
+        return {"data": marshal(child_chunk, child_chunk_fields)}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        parser = reqparse.RequestParser()
+        parser.add_argument("limit", type=int, default=20, location="args")
+        parser.add_argument("keyword", type=str, default=None, location="args")
+        parser.add_argument("page", type=int, default=1, location="args")
+        args = parser.parse_args()
+        page = args["page"]
+        limit = min(args["limit"], 100)
+        keyword = args["keyword"]
+        child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
+        return {
+            "data": marshal(child_chunks.items, child_chunk_fields),
+            "total": child_chunks.total,
+            "total_pages": child_chunks.pages,
+            "page": page,
+            "limit": limit,
+        }, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("vector_space")
+    def patch(self, dataset_id, document_id, segment_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        try:
+            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
+            child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
+        except ChildChunkIndexingServiceError as e:
+            raise ChildChunkIndexingError(str(e))
+        return {"data": marshal(child_chunks, child_chunk_fields)}, 200
+class ChildChunkUpdateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        # check child chunk
+        child_chunk_id = str(child_chunk_id)
+        child_chunk = ChildChunk.query.filter(
+            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not child_chunk:
+            raise NotFound("Child chunk not found.")
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        try:
+            SegmentService.delete_child_chunk(child_chunk, dataset)
+        except ChildChunkDeleteIndexServiceError as e:
+            raise ChildChunkDeleteIndexError(str(e))
+        return {"result": "success"}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @cloud_edition_billing_resource_check("vector_space")
+    def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
+        # check dataset
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        # check user's model setting
+        DatasetService.check_dataset_model_setting(dataset)
+        # check document
+        document_id = str(document_id)
+        document = DocumentService.get_document(dataset_id, document_id)
+        if not document:
+            raise NotFound("Document not found.")
+        # check segment
+        segment_id = str(segment_id)
+        segment = DocumentSegment.query.filter(
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not segment:
+            raise NotFound("Segment not found.")
+        # check child chunk
+        child_chunk_id = str(child_chunk_id)
+        child_chunk = ChildChunk.query.filter(
+            ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
+        ).first()
+        if not child_chunk:
+            raise NotFound("Child chunk not found.")
+        # The role of the current user in the ta table must be admin or owner
+        if not current_user.is_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+        # validate args
+        parser = reqparse.RequestParser()
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        args = parser.parse_args()
+        try:
+            child_chunk = SegmentService.update_child_chunk(
+                args.get("content"), child_chunk, segment, document, dataset
+            )
+        except ChildChunkIndexingServiceError as e:
+            raise ChildChunkIndexingError(str(e))
+        return {"data": marshal(child_chunk, child_chunk_fields)}, 200

api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
-api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
+api.add_resource(
+    DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
+)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
    DatasetDocumentSegmentUpdateApi,
    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
    "/datasets/batch_import_status/<uuid:job_id>",
)
+api.add_resource(
+    ChildChunkAddApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
+)
+api.add_resource(
+    ChildChunkUpdateApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
+)
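Together, the two new resources expose child-chunk CRUD under a segment. A hedged sketch of calling them (base URL, ids, and auth header are placeholders; the bulk-update payload only shows a content field, while the full field set is defined by ChildChunkUpdateArgs in knowledge_entities):

    import requests

    seg = "http://localhost:5001/console/api/datasets/<dataset-uuid>/documents/<doc-uuid>/segments/<segment-uuid>"  # placeholder
    headers = {"Authorization": "Bearer <console-token>"}  # placeholder auth

    # Create one child chunk (ChildChunkAddApi.post expects "content").
    requests.post(f"{seg}/child_chunks", json={"content": "a child passage"}, headers=headers)

    # List child chunks with pagination and an optional keyword filter.
    requests.get(f"{seg}/child_chunks", params={"page": 1, "limit": 20, "keyword": "passage"}, headers=headers)

    # Bulk-update the segment's child chunks (ChildChunkAddApi.patch expects a "chunks" list).
    requests.patch(f"{seg}/child_chunks", json={"chunks": [{"content": "updated child"}]}, headers=headers)

    # Edit or delete a single child chunk (ChildChunkUpdateApi).
    requests.patch(f"{seg}/child_chunks/<child-chunk-uuid>", json={"content": "edited"}, headers=headers)
    requests.delete(f"{seg}/child_chunks/<child-chunk-uuid>", headers=headers)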
| error_code = "indexing_estimate_error" | error_code = "indexing_estimate_error" | ||||
| description = "Knowledge indexing estimate failed: {message}" | description = "Knowledge indexing estimate failed: {message}" | ||||
| code = 500 | code = 500 | ||||
| class ChildChunkIndexingError(BaseHTTPException): | |||||
| error_code = "child_chunk_indexing_error" | |||||
| description = "Create child chunk index failed: {message}" | |||||
| code = 500 | |||||
| class ChildChunkDeleteIndexError(BaseHTTPException): | |||||
| error_code = "child_chunk_delete_index_error" | |||||
| description = "Delete child chunk index failed: {message}" | |||||
| code = 500 |
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
+from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs

class SegmentApi(DatasetApiResource):

        args = parser.parse_args()
        SegmentService.segment_create_args_validate(args["segment"], document)
-       segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
+       segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class PreviewDetail(BaseModel):
+    content: str
+    child_chunks: Optional[list[str]] = None
+
+
+class QAPreviewDetail(BaseModel):
+    question: str
+    answer: str
+
+
+class IndexingEstimate(BaseModel):
+    total_segments: int
+    preview: list[PreviewDetail]
+    qa_preview: Optional[list[QAPreviewDetail]] = None
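These models are what the estimate endpoints above now serialize with model_dump(). A small construction sketch (values are made up for illustration):

    estimate = IndexingEstimate(
        total_segments=2,
        preview=[
            PreviewDetail(content="Parent chunk text", child_chunks=["child a", "child b"]),
            PreviewDetail(content="Another parent chunk"),
        ],
    )
    payload = estimate.model_dump()  # plain dict, ready to be returned as JSON with a 200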
import uuid
from typing import Any, Optional, cast

-from flask import Flask, current_app
+from flask import current_app
from flask_login import current_user  # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
+from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
from core.errors.error import ProviderTokenNotInitError
-from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
+from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from core.rag.models.document import Document
+from core.rag.models.document import ChildDocument, Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
-from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
-from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
+from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
        for document_segment in document_segments:
            db.session.delete(document_segment)
+           if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+               # delete child chunks
+               db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
        db.session.commit()
        # get the process rule
        processing_rule = (

                    "dataset_id": document_segment.dataset_id,
                },
            )
+           if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
+               child_chunks = document_segment.child_chunks
+               if child_chunks:
+                   child_documents = []
+                   for child_chunk in child_chunks:
+                       child_document = ChildDocument(
+                           page_content=child_chunk.content,
+                           metadata={
+                               "doc_id": child_chunk.index_node_id,
+                               "doc_hash": child_chunk.index_node_hash,
+                               "document_id": document_segment.document_id,
+                               "dataset_id": document_segment.dataset_id,
+                           },
+                       )
+                       child_documents.append(child_document)
+                   document.children = child_documents
            documents.append(document)
        # build index

        doc_language: str = "English",
        dataset_id: Optional[str] = None,
        indexing_technique: str = "economy",
-   ) -> dict:
+   ) -> IndexingEstimate:
        """
        Estimate the indexing for the document.
        """

                tenant_id=tenant_id,
                model_type=ModelType.TEXT_EMBEDDING,
            )
-       preview_texts: list[str] = []
+       preview_texts = []
        total_segments = 0
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-       all_text_docs = []
        for extract_setting in extract_settings:
            # extract
-           text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
-           all_text_docs.extend(text_docs)
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
            )
-           # get splitter
-           splitter = self._get_splitter(processing_rule, embedding_model_instance)
-           # split to documents
-           documents = self._split_to_documents_for_estimate(
-               text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
+           text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
+           documents = index_processor.transform(
+               text_docs,
+               embedding_model_instance=embedding_model_instance,
+               process_rule=processing_rule.to_dict(),
+               tenant_id=current_user.current_tenant_id,
+               doc_language=doc_language,
+               preview=True,
            )
            total_segments += len(documents)
            for document in documents:
-               if len(preview_texts) < 5:
-                   preview_texts.append(document.page_content)
+               if len(preview_texts) < 10:
+                   if doc_form and doc_form == "qa_model":
+                       preview_detail = QAPreviewDetail(
+                           question=document.page_content, answer=document.metadata.get("answer")
+                       )
+                       preview_texts.append(preview_detail)
+                   else:
+                       preview_detail = PreviewDetail(content=document.page_content)
+                       if document.children:
+                           preview_detail.child_chunks = [child.page_content for child in document.children]
+                       preview_texts.append(preview_detail)

                # delete image files and related db records
                image_upload_file_ids = get_image_upload_file_ids(document.page_content)

                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
-           if len(preview_texts) > 0:
-               # qa model document
-               response = LLMGenerator.generate_qa_document(
-                   current_user.current_tenant_id, preview_texts[0], doc_language
-               )
-               document_qa_list = self.format_split_text(response)
-               return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
-       return {"total_segments": total_segments, "preview": preview_texts}
+           return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
+       return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
| def _extract( | def _extract( | ||||
| self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict | self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict | ||||
| @staticmethod | @staticmethod | ||||
| def _get_splitter( | def _get_splitter( | ||||
| processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] | |||||
| processing_rule_mode: str, | |||||
| max_tokens: int, | |||||
| chunk_overlap: int, | |||||
| separator: str, | |||||
| embedding_model_instance: Optional[ModelInstance], | |||||
| ) -> TextSplitter: | ) -> TextSplitter: | ||||
| """ | """ | ||||
| Get the NodeParser object according to the processing rule. | Get the NodeParser object according to the processing rule. | ||||
| """ | """ | ||||
| character_splitter: TextSplitter | |||||
| if processing_rule.mode == "custom": | |||||
| if processing_rule_mode in ["custom", "hierarchical"]: | |||||
| # The user-defined segmentation rule | # The user-defined segmentation rule | ||||
| rules = json.loads(processing_rule.rules) | |||||
| segmentation = rules["segmentation"] | |||||
| max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | ||||
| if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: | |||||
| if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: | |||||
| raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") | raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") | ||||
| separator = segmentation["separator"] | |||||
| if separator: | if separator: | ||||
| separator = separator.replace("\\n", "\n") | separator = separator.replace("\\n", "\n") | ||||
| if segmentation.get("chunk_overlap"): | |||||
| chunk_overlap = segmentation["chunk_overlap"] | |||||
| else: | |||||
| chunk_overlap = 0 | |||||
| character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | ||||
| chunk_size=segmentation["max_tokens"], | |||||
| chunk_size=max_tokens, | |||||
| chunk_overlap=chunk_overlap, | chunk_overlap=chunk_overlap, | ||||
| fixed_separator=separator, | fixed_separator=separator, | ||||
| separators=["\n\n", "。", ". ", " ", ""], | separators=["\n\n", "。", ". ", " ", ""], | ||||
| return character_splitter | return character_splitter | ||||
| def _step_split( | |||||
| self, | |||||
| text_docs: list[Document], | |||||
| splitter: TextSplitter, | |||||
| dataset: Dataset, | |||||
| dataset_document: DatasetDocument, | |||||
| processing_rule: DatasetProcessRule, | |||||
| ) -> list[Document]: | |||||
| """ | |||||
| Split the text documents into documents and save them to the document segment. | |||||
| """ | |||||
| documents = self._split_to_documents( | |||||
| text_docs=text_docs, | |||||
| splitter=splitter, | |||||
| processing_rule=processing_rule, | |||||
| tenant_id=dataset.tenant_id, | |||||
| document_form=dataset_document.doc_form, | |||||
| document_language=dataset_document.doc_language, | |||||
| ) | |||||
| # save node to document segment | |||||
| doc_store = DatasetDocumentStore( | |||||
| dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id | |||||
| ) | |||||
| # add document segments | |||||
| doc_store.add_documents(documents) | |||||
| # update document status to indexing | |||||
| cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||||
| self._update_document_index_status( | |||||
| document_id=dataset_document.id, | |||||
| after_indexing_status="indexing", | |||||
| extra_update_params={ | |||||
| DatasetDocument.cleaning_completed_at: cur_time, | |||||
| DatasetDocument.splitting_completed_at: cur_time, | |||||
| }, | |||||
| ) | |||||
| # update segment status to indexing | |||||
| self._update_segments_by_document( | |||||
| dataset_document_id=dataset_document.id, | |||||
| update_params={ | |||||
| DocumentSegment.status: "indexing", | |||||
| DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | |||||
| }, | |||||
| ) | |||||
| return documents | |||||
| def _split_to_documents( | |||||
| self, | |||||
| text_docs: list[Document], | |||||
| splitter: TextSplitter, | |||||
| processing_rule: DatasetProcessRule, | |||||
| tenant_id: str, | |||||
| document_form: str, | |||||
| document_language: str, | |||||
| ) -> list[Document]: | |||||
| """ | |||||
| Split the text documents into nodes. | |||||
| """ | |||||
| all_documents: list[Document] = [] | |||||
| all_qa_documents: list[Document] = [] | |||||
| for text_doc in text_docs: | |||||
| # document clean | |||||
| document_text = self._document_clean(text_doc.page_content, processing_rule) | |||||
| text_doc.page_content = document_text | |||||
| # parse document to nodes | |||||
| documents = splitter.split_documents([text_doc]) | |||||
| split_documents = [] | |||||
| for document_node in documents: | |||||
| if document_node.page_content.strip(): | |||||
| if document_node.metadata is not None: | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(document_node.page_content) | |||||
| document_node.metadata["doc_id"] = doc_id | |||||
| document_node.metadata["doc_hash"] = hash | |||||
| # delete Splitter character | |||||
| page_content = document_node.page_content | |||||
| document_node.page_content = remove_leading_symbols(page_content) | |||||
| if document_node.page_content: | |||||
| split_documents.append(document_node) | |||||
| all_documents.extend(split_documents) | |||||
| # processing qa document | |||||
| if document_form == "qa_model": | |||||
| for i in range(0, len(all_documents), 10): | |||||
| threads = [] | |||||
| sub_documents = all_documents[i : i + 10] | |||||
| for doc in sub_documents: | |||||
| document_format_thread = threading.Thread( | |||||
| target=self.format_qa_document, | |||||
| kwargs={ | |||||
| "flask_app": current_app._get_current_object(), # type: ignore | |||||
| "tenant_id": tenant_id, | |||||
| "document_node": doc, | |||||
| "all_qa_documents": all_qa_documents, | |||||
| "document_language": document_language, | |||||
| }, | |||||
| ) | |||||
| threads.append(document_format_thread) | |||||
| document_format_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| return all_qa_documents | |||||
| return all_documents | |||||
| def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): | |||||
| format_documents = [] | |||||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||||
| return | |||||
| with flask_app.app_context(): | |||||
| try: | |||||
| # qa model document | |||||
| response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) | |||||
| document_qa_list = self.format_split_text(response) | |||||
| qa_documents = [] | |||||
| for result in document_qa_list: | |||||
| qa_document = Document( | |||||
| page_content=result["question"], metadata=document_node.metadata.model_copy() | |||||
| ) | |||||
| if qa_document.metadata is not None: | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(result["question"]) | |||||
| qa_document.metadata["answer"] = result["answer"] | |||||
| qa_document.metadata["doc_id"] = doc_id | |||||
| qa_document.metadata["doc_hash"] = hash | |||||
| qa_documents.append(qa_document) | |||||
| format_documents.extend(qa_documents) | |||||
| except Exception as e: | |||||
| logging.exception("Failed to format qa document") | |||||
| all_qa_documents.extend(format_documents) | |||||
| def _split_to_documents_for_estimate( | def _split_to_documents_for_estimate( | ||||
| self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule | self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule | ||||
| ) -> list[Document]: | ) -> list[Document]: | ||||
| return document_text | return document_text | ||||
| @staticmethod | @staticmethod | ||||
| def format_split_text(text): | |||||
| def format_split_text(text: str) -> list[QAPreviewDetail]: | |||||
| regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" | regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" | ||||
| matches = re.findall(regex, text, re.UNICODE) | matches = re.findall(regex, text, re.UNICODE) | ||||
| return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] | |||||
| return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] | |||||
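Editor's note: for reference, a minimal standalone sketch of how the Q/A regex above splits LLM-generated text into question/answer pairs. The sample text and the plain-tuple output are illustrative; the real method wraps each pair in QAPreviewDetail.

    import re

    # Same pattern as format_split_text above: captures "Qn: ..." and the "An: ..." body
    # up to the next "Qn:" marker or the end of the string.
    regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"

    sample = "Q1: What is indexing?\nA1: Splitting documents\ninto chunks.\nQ2: Why chunk?\nA2: For retrieval."
    pairs = [
        (q, re.sub(r"\n\s*", "\n", a.strip()))
        for q, a in re.findall(regex, sample, re.UNICODE)
        if q and a
    ]
    # [('What is indexing?', 'Splitting documents\ninto chunks.'), ('Why chunk?', 'For retrieval.')]
    print(pairs)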
| def _load( | def _load( | ||||
| self, | self, | ||||
| indexing_start_at = time.perf_counter() | indexing_start_at = time.perf_counter() | ||||
| tokens = 0 | tokens = 0 | ||||
| chunk_size = 10 | chunk_size = 10 | ||||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: | |||||
| # create keyword index | |||||
| create_keyword_thread = threading.Thread( | |||||
| target=self._process_keyword_index, | |||||
| args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), | |||||
| ) | |||||
| create_keyword_thread.start() | |||||
| # create keyword index | |||||
| create_keyword_thread = threading.Thread( | |||||
| target=self._process_keyword_index, | |||||
| args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore | |||||
| ) | |||||
| create_keyword_thread.start() | |||||
| if dataset.indexing_technique == "high_quality": | if dataset.indexing_technique == "high_quality": | ||||
| with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: | with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: | ||||
| futures = [] | futures = [] | ||||
| for future in futures: | for future in futures: | ||||
| tokens += future.result() | tokens += future.result() | ||||
| create_keyword_thread.join() | |||||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: | |||||
| create_keyword_thread.join() | |||||
| indexing_end_at = time.perf_counter() | indexing_end_at = time.perf_counter() | ||||
| # update document status to completed | # update document status to completed | ||||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | ||||
| db.session.commit() | db.session.commit() | ||||
| @staticmethod | |||||
| def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): | |||||
| """ | |||||
| Batch add segments index processing | |||||
| """ | |||||
| documents = [] | |||||
| for segment in segments: | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| documents.append(document) | |||||
| # save vector index | |||||
| index_type = dataset.doc_form | |||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||||
| index_processor.load(dataset, documents) | |||||
| def _transform( | def _transform( | ||||
| self, | self, | ||||
| index_processor: BaseIndexProcessor, | index_processor: BaseIndexProcessor, | ||||
| ) | ) | ||||
| # add document segments | # add document segments | ||||
| doc_store.add_documents(documents) | |||||
| doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) | |||||
| # update document status to indexing | # update document status to indexing | ||||
| cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) |
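Editor's note: the _load section above runs keyword indexing on a side thread while the vector index is built in a thread pool, then joins. A stdlib-only sketch of that fan-out/fan-in pattern; the worker functions and batch size are placeholders, not the real implementation.

    import concurrent.futures
    import threading

    def build_keyword_index(chunks):          # placeholder for _process_keyword_index
        print(f"keyword index over {len(chunks)} chunks")

    def embed_batch(batch) -> int:            # placeholder for the per-batch vector work
        return sum(len(c) for c in batch)     # pretend this is the token count

    chunks = ["chunk one", "chunk two", "chunk three", "chunk four"]
    keyword_thread = threading.Thread(target=build_keyword_index, args=(chunks,))
    keyword_thread.start()

    tokens = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(embed_batch, chunks[i:i + 2]) for i in range(0, len(chunks), 2)]
        for future in futures:
            tokens += future.result()

    keyword_thread.join()                     # joined conditionally on doc_form in the real code
    print(tokens)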
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | from core.rag.data_post_processor.data_post_processor import DataPostProcessor | ||||
| from core.rag.datasource.keyword.keyword_factory import Keyword | from core.rag.datasource.keyword.keyword_factory import Keyword | ||||
| from core.rag.datasource.vdb.vector_factory import Vector | from core.rag.datasource.vdb.vector_factory import Vector | ||||
| from core.rag.embedding.retrieval import RetrievalSegments | |||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from core.rag.rerank.rerank_type import RerankMode | from core.rag.rerank.rerank_type import RerankMode | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset | |||||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | |||||
| from services.external_knowledge_service import ExternalDatasetService | from services.external_knowledge_service import ExternalDatasetService | ||||
| default_retrieval_model = { | default_retrieval_model = { | ||||
| @staticmethod | @staticmethod | ||||
| def escape_query_for_search(query: str) -> str: | def escape_query_for_search(query: str) -> str: | ||||
| return query.replace('"', '\\"') | return query.replace('"', '\\"') | ||||
| @staticmethod | |||||
| def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]: | |||||
| records = [] | |||||
| include_segment_ids = [] | |||||
| segment_child_map = {} | |||||
| for document in documents: | |||||
| document_id = document.metadata["document_id"] | |||||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||||
| if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||||
| child_index_node_id = document.metadata["doc_id"] | |||||
| result = ( | |||||
| db.session.query(ChildChunk, DocumentSegment) | |||||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||||
| .filter( | |||||
| ChildChunk.index_node_id == child_index_node_id, | |||||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.status == "completed", | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if result: | |||||
| child_chunk, segment = result | |||||
| if not segment: | |||||
| continue | |||||
| if segment.id not in include_segment_ids: | |||||
| include_segment_ids.append(segment.id) | |||||
| child_chunk_detail = { | |||||
| "id": child_chunk.id, | |||||
| "content": child_chunk.content, | |||||
| "position": child_chunk.position, | |||||
| "score": document.metadata.get("score", 0.0), | |||||
| } | |||||
| map_detail = { | |||||
| "max_score": document.metadata.get("score", 0.0), | |||||
| "child_chunks": [child_chunk_detail], | |||||
| } | |||||
| segment_child_map[segment.id] = map_detail | |||||
| record = { | |||||
| "segment": segment, | |||||
| } | |||||
| records.append(record) | |||||
| else: | |||||
| child_chunk_detail = { | |||||
| "id": child_chunk.id, | |||||
| "content": child_chunk.content, | |||||
| "position": child_chunk.position, | |||||
| "score": document.metadata.get("score", 0.0), | |||||
| } | |||||
| segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) | |||||
| segment_child_map[segment.id]["max_score"] = max( | |||||
| segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) | |||||
| ) | |||||
| else: | |||||
| continue | |||||
| else: | |||||
| index_node_id = document.metadata["doc_id"] | |||||
| segment = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .filter( | |||||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.index_node_id == index_node_id, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if not segment: | |||||
| continue | |||||
| include_segment_ids.append(segment.id) | |||||
| record = { | |||||
| "segment": segment, | |||||
| "score": document.metadata.get("score", None), | |||||
| } | |||||
| records.append(record) | |||||
| for record in records: | |||||
| if record["segment"].id in segment_child_map: | |||||
| record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) | |||||
| record["score"] = segment_child_map[record["segment"].id]["max_score"] | |||||
| return [RetrievalSegments(**record) for record in records] |
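Editor's note: format_retrieval_documents groups child-chunk hits under their parent segment and keeps the highest child score per segment. A self-contained sketch of that aggregation, using plain dicts in place of the ORM rows; the ids and scores are made up.

    # Each hit is (parent_segment_id, child_chunk_id, score).
    hits = [("seg-1", "child-a", 0.82), ("seg-2", "child-b", 0.41), ("seg-1", "child-c", 0.90)]

    segment_child_map: dict[str, dict] = {}
    for segment_id, child_id, score in hits:
        entry = segment_child_map.setdefault(segment_id, {"max_score": 0.0, "child_chunks": []})
        entry["child_chunks"].append({"id": child_id, "score": score})
        entry["max_score"] = max(entry["max_score"], score)

    # seg-1 surfaces once with both children and score 0.90; seg-2 with score 0.41.
    for segment_id, detail in segment_child_map.items():
        print(segment_id, detail["max_score"], [c["id"] for c in detail["child_chunks"]])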
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, DocumentSegment | |||||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||||
| class DatasetDocumentStore: | class DatasetDocumentStore: | ||||
| return output | return output | ||||
| def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: | |||||
| def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: | |||||
| max_position = ( | max_position = ( | ||||
| db.session.query(func.max(DocumentSegment.position)) | db.session.query(func.max(DocumentSegment.position)) | ||||
| .filter(DocumentSegment.document_id == self._document_id) | .filter(DocumentSegment.document_id == self._document_id) | ||||
| segment_document.answer = doc.metadata.pop("answer", "") | segment_document.answer = doc.metadata.pop("answer", "") | ||||
| db.session.add(segment_document) | db.session.add(segment_document) | ||||
| db.session.flush() | |||||
| if save_child: | |||||
| for position, child in enumerate(doc.children, start=1): | |||||
| child_segment = ChildChunk( | |||||
| tenant_id=self._dataset.tenant_id, | |||||
| dataset_id=self._dataset.id, | |||||
| document_id=self._document_id, | |||||
| segment_id=segment_document.id, | |||||
| position=position, | |||||
| index_node_id=child.metadata["doc_id"], | |||||
| index_node_hash=child.metadata["doc_hash"], | |||||
| content=child.page_content, | |||||
| word_count=len(child.page_content), | |||||
| type="automatic", | |||||
| created_by=self._user_id, | |||||
| ) | |||||
| db.session.add(child_segment) | |||||
| else: | else: | ||||
| segment_document.content = doc.page_content | segment_document.content = doc.page_content | ||||
| if doc.metadata.get("answer"): | if doc.metadata.get("answer"): | ||||
| segment_document.index_node_hash = doc.metadata["doc_hash"] | segment_document.index_node_hash = doc.metadata["doc_hash"] | ||||
| segment_document.word_count = len(doc.page_content) | segment_document.word_count = len(doc.page_content) | ||||
| segment_document.tokens = tokens | segment_document.tokens = tokens | ||||
| if save_child and doc.children: | |||||
| # delete the existing child chunks | |||||
| db.session.query(ChildChunk).filter( | |||||
| ChildChunk.tenant_id == self._dataset.tenant_id, | |||||
| ChildChunk.dataset_id == self._dataset.id, | |||||
| ChildChunk.document_id == self._document_id, | |||||
| ChildChunk.segment_id == segment_document.id, | |||||
| ).delete() | |||||
| # add new child chunks | |||||
| for position, child in enumerate(doc.children, start=1): | |||||
| child_segment = ChildChunk( | |||||
| tenant_id=self._dataset.tenant_id, | |||||
| dataset_id=self._dataset.id, | |||||
| document_id=self._document_id, | |||||
| segment_id=segment_document.id, | |||||
| position=position, | |||||
| index_node_id=child.metadata["doc_id"], | |||||
| index_node_hash=child.metadata["doc_hash"], | |||||
| content=child.page_content, | |||||
| word_count=len(child.page_content), | |||||
| type="automatic", | |||||
| created_by=self._user_id, | |||||
| ) | |||||
| db.session.add(child_segment) | |||||
| db.session.commit() | db.session.commit() | ||||
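Editor's note: for callers, the effect of the new save_child flag is that parent segments and their child chunks are persisted in one pass. A hedged usage sketch mirroring the doc_store call earlier in this diff; the DatasetDocumentStore import path is an assumption, and dataset, dataset_document and documents come from the surrounding indexing flow.

    from core.docstore.dataset_docstore import DatasetDocumentStore   # import path assumed
    from core.rag.index_processor.constant.index_type import IndexType

    doc_store = DatasetDocumentStore(
        dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
    )
    doc_store.add_documents(
        docs=documents,
        save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX,
    )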
| from typing import Optional | |||||
| from pydantic import BaseModel | |||||
| from models.dataset import DocumentSegment | |||||
| class RetrievalChildChunk(BaseModel): | |||||
| """Retrieval segments.""" | |||||
| id: str | |||||
| content: str | |||||
| score: float | |||||
| position: int | |||||
| class RetrievalSegments(BaseModel): | |||||
| """Retrieval segments.""" | |||||
| model_config = {"arbitrary_types_allowed": True} | |||||
| segment: DocumentSegment | |||||
| child_chunks: Optional[list[RetrievalChildChunk]] = None | |||||
| score: Optional[float] = None |
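Editor's note: RetrievalSegments wraps a SQLAlchemy DocumentSegment row, which is why it sets arbitrary_types_allowed. A self-contained pydantic sketch of that pattern, with a stand-in class instead of the real ORM model.

    from typing import Optional
    from pydantic import BaseModel

    class FakeSegment:                      # stand-in for the SQLAlchemy DocumentSegment model
        def __init__(self, content: str):
            self.content = content

    class RetrievalRecord(BaseModel):       # mirrors the shape of RetrievalSegments above
        model_config = {"arbitrary_types_allowed": True}

        segment: FakeSegment                # non-pydantic type: allowed only via the config above
        score: Optional[float] = None

    record = RetrievalRecord(segment=FakeSegment("parent chunk text"), score=0.87)
    print(record.segment.content, record.score)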
| from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor | from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor | ||||
| from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor | from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor | ||||
| from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor | from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor | ||||
| from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor | |||||
| from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor | from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor | ||||
| from core.rag.extractor.word_extractor import WordExtractor | from core.rag.extractor.word_extractor import WordExtractor | ||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) | extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) | ||||
| else: | else: | ||||
| # txt | # txt | ||||
| extractor = ( | |||||
| UnstructuredTextExtractor(file_path, unstructured_api_url) | |||||
| if is_automatic | |||||
| else TextExtractor(file_path, autodetect_encoding=True) | |||||
| ) | |||||
| extractor = TextExtractor(file_path, autodetect_encoding=True) | |||||
| else: | else: | ||||
| if file_extension in {".xlsx", ".xls"}: | if file_extension in {".xlsx", ".xls"}: | ||||
| extractor = ExcelExtractor(file_path) | extractor = ExcelExtractor(file_path) |
| if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph | if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph | ||||
| para = paragraphs.pop(0) | para = paragraphs.pop(0) | ||||
| parsed_paragraph = parse_paragraph(para) | parsed_paragraph = parse_paragraph(para) | ||||
| if parsed_paragraph: | |||||
| if parsed_paragraph.strip(): | |||||
| content.append(parsed_paragraph) | content.append(parsed_paragraph) | ||||
| else: | |||||
| content.append("\n") | |||||
| elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table | elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table | ||||
| table = tables.pop(0) | table = tables.pop(0) | ||||
| content.append(self._table_to_markdown(table, image_map)) | content.append(self._table_to_markdown(table, image_map)) |
| from enum import Enum | from enum import Enum | ||||
| class IndexType(Enum): | |||||
| class IndexType(str, Enum): | |||||
| PARAGRAPH_INDEX = "text_model" | PARAGRAPH_INDEX = "text_model" | ||||
| QA_INDEX = "qa_model" | QA_INDEX = "qa_model" | ||||
| PARENT_CHILD_INDEX = "parent_child_index" | |||||
| SUMMARY_INDEX = "summary_index" | |||||
| PARENT_CHILD_INDEX = "hierarchical_model" |
| raise NotImplementedError | raise NotImplementedError | ||||
| @abstractmethod | @abstractmethod | ||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @abstractmethod | @abstractmethod | ||||
| ) -> list[Document]: | ) -> list[Document]: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: | |||||
| def _get_splitter( | |||||
| self, | |||||
| processing_rule_mode: str, | |||||
| max_tokens: int, | |||||
| chunk_overlap: int, | |||||
| separator: str, | |||||
| embedding_model_instance: Optional[ModelInstance], | |||||
| ) -> TextSplitter: | |||||
| """ | """ | ||||
| Get the NodeParser object according to the processing rule. | Get the NodeParser object according to the processing rule. | ||||
| """ | """ | ||||
| character_splitter: TextSplitter | |||||
| if processing_rule["mode"] == "custom": | |||||
| if processing_rule_mode in ["custom", "hierarchical"]: | |||||
| # The user-defined segmentation rule | # The user-defined segmentation rule | ||||
| rules = processing_rule["rules"] | |||||
| segmentation = rules["segmentation"] | |||||
| max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | ||||
| if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: | |||||
| if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: | |||||
| raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") | raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") | ||||
| separator = segmentation["separator"] | |||||
| if separator: | if separator: | ||||
| separator = separator.replace("\\n", "\n") | separator = separator.replace("\\n", "\n") | ||||
| character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | ||||
| chunk_size=segmentation["max_tokens"], | |||||
| chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, | |||||
| chunk_size=max_tokens, | |||||
| chunk_overlap=chunk_overlap, | |||||
| fixed_separator=separator, | fixed_separator=separator, | ||||
| separators=["\n\n", "。", ". ", " ", ""], | separators=["\n\n", "。", ". ", " ", ""], | ||||
| embedding_model_instance=embedding_model_instance, | embedding_model_instance=embedding_model_instance, |
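Editor's note: the reworked _get_splitter takes the segmentation values as explicit arguments instead of digging into a rules dict. A plain-Python sketch of how a caller unpacks a custom rule into those arguments, including the escaped-separator fixup and the token bound check; the 4000 limit stands in for the configurable INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH.

    MAX_SEGMENTATION_TOKENS_LENGTH = 4000   # stand-in for dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH

    process_rule = {
        "mode": "hierarchical",
        "rules": {"segmentation": {"max_tokens": 500, "chunk_overlap": 50, "separator": "\\n\\n"}},
    }

    segmentation = process_rule["rules"]["segmentation"]
    max_tokens = segmentation["max_tokens"]
    chunk_overlap = segmentation.get("chunk_overlap", 0) or 0
    separator = segmentation["separator"]

    if process_rule["mode"] in ["custom", "hierarchical"]:
        if max_tokens < 50 or max_tokens > MAX_SEGMENTATION_TOKENS_LENGTH:
            raise ValueError(f"Custom segment length should be between 50 and {MAX_SEGMENTATION_TOKENS_LENGTH}.")
        if separator:
            separator = separator.replace("\\n", "\n")   # rules carry the escaped form

    print(max_tokens, chunk_overlap, repr(separator))    # 500 50 '\n\n'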
| from core.rag.index_processor.constant.index_type import IndexType | from core.rag.index_processor.constant.index_type import IndexType | ||||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | from core.rag.index_processor.index_processor_base import BaseIndexProcessor | ||||
| from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor | from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor | ||||
| from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor | |||||
| from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor | from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor | ||||
| if not self._index_type: | if not self._index_type: | ||||
| raise ValueError("Index type must be specified.") | raise ValueError("Index type must be specified.") | ||||
| if self._index_type == IndexType.PARAGRAPH_INDEX.value: | |||||
| if self._index_type == IndexType.PARAGRAPH_INDEX: | |||||
| return ParagraphIndexProcessor() | return ParagraphIndexProcessor() | ||||
| elif self._index_type == IndexType.QA_INDEX.value: | |||||
| elif self._index_type == IndexType.QA_INDEX: | |||||
| return QAIndexProcessor() | return QAIndexProcessor() | ||||
| elif self._index_type == IndexType.PARENT_CHILD_INDEX: | |||||
| return ParentChildIndexProcessor() | |||||
| else: | else: | ||||
| raise ValueError(f"Index type {self._index_type} is not supported.") | raise ValueError(f"Index type {self._index_type} is not supported.") |
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from core.tools.utils.text_processing_utils import remove_leading_symbols | from core.tools.utils.text_processing_utils import remove_leading_symbols | ||||
| from libs import helper | from libs import helper | ||||
| from models.dataset import Dataset | |||||
| from models.dataset import Dataset, DatasetProcessRule | |||||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||||
| class ParagraphIndexProcessor(BaseIndexProcessor): | class ParagraphIndexProcessor(BaseIndexProcessor): | ||||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | ||||
| text_docs = ExtractProcessor.extract( | text_docs = ExtractProcessor.extract( | ||||
| extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" | |||||
| extract_setting=extract_setting, | |||||
| is_automatic=( | |||||
| kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" | |||||
| ), | |||||
| ) | ) | ||||
| return text_docs | return text_docs | ||||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | def transform(self, documents: list[Document], **kwargs) -> list[Document]: | ||||
| process_rule = kwargs.get("process_rule") | |||||
| if process_rule.get("mode") == "automatic": | |||||
| automatic_rule = DatasetProcessRule.AUTOMATIC_RULES | |||||
| rules = Rule(**automatic_rule) | |||||
| else: | |||||
| rules = Rule(**process_rule.get("rules")) | |||||
| # Split the text documents into nodes. | # Split the text documents into nodes. | ||||
| splitter = self._get_splitter( | splitter = self._get_splitter( | ||||
| processing_rule=kwargs.get("process_rule", {}), | |||||
| processing_rule_mode=process_rule.get("mode"), | |||||
| max_tokens=rules.segmentation.max_tokens, | |||||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||||
| separator=rules.segmentation.separator, | |||||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | embedding_model_instance=kwargs.get("embedding_model_instance"), | ||||
| ) | ) | ||||
| all_documents = [] | all_documents = [] | ||||
| all_documents.extend(split_documents) | all_documents.extend(split_documents) | ||||
| return all_documents | return all_documents | ||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||||
| if dataset.indexing_technique == "high_quality": | if dataset.indexing_technique == "high_quality": | ||||
| vector = Vector(dataset) | vector = Vector(dataset) | ||||
| vector.create(documents) | vector.create(documents) | ||||
| if with_keywords: | if with_keywords: | ||||
| keywords_list = kwargs.get("keywords_list") | |||||
| keyword = Keyword(dataset) | keyword = Keyword(dataset) | ||||
| keyword.create(documents) | |||||
| if keywords_list and len(keywords_list) > 0: | |||||
| keyword.add_texts(documents, keywords_list=keywords_list) | |||||
| else: | |||||
| keyword.add_texts(documents) | |||||
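Editor's note: load now accepts extra kwargs so callers can pass pre-computed keywords instead of having the keyword index re-extract them. A minimal hedged sketch; the one-keyword-list-per-document shape is an assumption inferred from the add_texts call above, and dataset, documents and index_processor come from the surrounding indexing flow.

    index_processor.load(
        dataset,
        documents,
        with_keywords=True,
        keywords_list=[["retrieval", "chunking"], ["embedding"]],  # assumed: parallel to documents
    )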
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||||
| if dataset.indexing_technique == "high_quality": | if dataset.indexing_technique == "high_quality": | ||||
| vector = Vector(dataset) | vector = Vector(dataset) | ||||
| if node_ids: | if node_ids: |
| """Paragraph index processor.""" | |||||
| import uuid | |||||
| from typing import Optional | |||||
| from core.model_manager import ModelInstance | |||||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from core.rag.datasource.vdb.vector_factory import Vector | |||||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||||
| from core.rag.models.document import ChildDocument, Document | |||||
| from extensions.ext_database import db | |||||
| from libs import helper | |||||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||||
| class ParentChildIndexProcessor(BaseIndexProcessor): | |||||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||||
| text_docs = ExtractProcessor.extract( | |||||
| extract_setting=extract_setting, | |||||
| is_automatic=( | |||||
| kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" | |||||
| ), | |||||
| ) | |||||
| return text_docs | |||||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||||
| process_rule = kwargs.get("process_rule") | |||||
| rules = Rule(**process_rule.get("rules")) | |||||
| all_documents = [] | |||||
| if rules.parent_mode == ParentMode.PARAGRAPH: | |||||
| # Split the text documents into nodes. | |||||
| splitter = self._get_splitter( | |||||
| processing_rule_mode=process_rule.get("mode"), | |||||
| max_tokens=rules.segmentation.max_tokens, | |||||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||||
| separator=rules.segmentation.separator, | |||||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | |||||
| ) | |||||
| for document in documents: | |||||
| # document clean | |||||
| document_text = CleanProcessor.clean(document.page_content, process_rule) | |||||
| document.page_content = document_text | |||||
| # parse document to nodes | |||||
| document_nodes = splitter.split_documents([document]) | |||||
| split_documents = [] | |||||
| for document_node in document_nodes: | |||||
| if document_node.page_content.strip(): | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(document_node.page_content) | |||||
| document_node.metadata["doc_id"] = doc_id | |||||
| document_node.metadata["doc_hash"] = hash | |||||
| # delete Splitter character | |||||
| page_content = document_node.page_content | |||||
| if page_content.startswith(".") or page_content.startswith("。"): | |||||
| page_content = page_content[1:].strip() | |||||
| else: | |||||
| page_content = page_content | |||||
| if len(page_content) > 0: | |||||
| document_node.page_content = page_content | |||||
| # parse document to child nodes | |||||
| child_nodes = self._split_child_nodes( | |||||
| document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | |||||
| ) | |||||
| document_node.children = child_nodes | |||||
| split_documents.append(document_node) | |||||
| all_documents.extend(split_documents) | |||||
| elif rules.parent_mode == ParentMode.FULL_DOC: | |||||
| page_content = "\n".join([document.page_content for document in documents]) | |||||
| document = Document(page_content=page_content, metadata=documents[0].metadata) | |||||
| # parse document to child nodes | |||||
| child_nodes = self._split_child_nodes( | |||||
| document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | |||||
| ) | |||||
| document.children = child_nodes | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(document.page_content) | |||||
| document.metadata["doc_id"] = doc_id | |||||
| document.metadata["doc_hash"] = hash | |||||
| all_documents.append(document) | |||||
| return all_documents | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| vector = Vector(dataset) | |||||
| for document in documents: | |||||
| child_documents = document.children | |||||
| if child_documents: | |||||
| formatted_child_documents = [ | |||||
| Document(**child_document.model_dump()) for child_document in child_documents | |||||
| ] | |||||
| vector.create(formatted_child_documents) | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||||
| # node_ids is segment's node_ids | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| delete_child_chunks = kwargs.get("delete_child_chunks") or False | |||||
| vector = Vector(dataset) | |||||
| if node_ids: | |||||
| child_node_ids = ( | |||||
| db.session.query(ChildChunk.index_node_id) | |||||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||||
| .filter( | |||||
| DocumentSegment.dataset_id == dataset.id, | |||||
| DocumentSegment.index_node_id.in_(node_ids), | |||||
| ChildChunk.dataset_id == dataset.id, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] | |||||
| vector.delete_by_ids(child_node_ids) | |||||
| if delete_child_chunks: | |||||
| db.session.query(ChildChunk).filter( | |||||
| ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) | |||||
| ).delete() | |||||
| db.session.commit() | |||||
| else: | |||||
| vector.delete() | |||||
| if delete_child_chunks: | |||||
| db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() | |||||
| db.session.commit() | |||||
| def retrieve( | |||||
| self, | |||||
| retrieval_method: str, | |||||
| query: str, | |||||
| dataset: Dataset, | |||||
| top_k: int, | |||||
| score_threshold: float, | |||||
| reranking_model: dict, | |||||
| ) -> list[Document]: | |||||
| # Set search parameters. | |||||
| results = RetrievalService.retrieve( | |||||
| retrieval_method=retrieval_method, | |||||
| dataset_id=dataset.id, | |||||
| query=query, | |||||
| top_k=top_k, | |||||
| score_threshold=score_threshold, | |||||
| reranking_model=reranking_model, | |||||
| ) | |||||
| # Organize results. | |||||
| docs = [] | |||||
| for result in results: | |||||
| metadata = result.metadata | |||||
| metadata["score"] = result.score | |||||
| if result.score > score_threshold: | |||||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||||
| docs.append(doc) | |||||
| return docs | |||||
| def _split_child_nodes( | |||||
| self, | |||||
| document_node: Document, | |||||
| rules: Rule, | |||||
| process_rule_mode: str, | |||||
| embedding_model_instance: Optional[ModelInstance], | |||||
| ) -> list[ChildDocument]: | |||||
| child_splitter = self._get_splitter( | |||||
| processing_rule_mode=process_rule_mode, | |||||
| max_tokens=rules.subchunk_segmentation.max_tokens, | |||||
| chunk_overlap=rules.subchunk_segmentation.chunk_overlap, | |||||
| separator=rules.subchunk_segmentation.separator, | |||||
| embedding_model_instance=embedding_model_instance, | |||||
| ) | |||||
| # parse document to child nodes | |||||
| child_nodes = [] | |||||
| child_documents = child_splitter.split_documents([document_node]) | |||||
| for child_document_node in child_documents: | |||||
| if child_document_node.page_content.strip(): | |||||
| doc_id = str(uuid.uuid4()) | |||||
| hash = helper.generate_text_hash(child_document_node.page_content) | |||||
| child_document = ChildDocument( | |||||
| page_content=child_document_node.page_content, metadata=document_node.metadata | |||||
| ) | |||||
| child_document.metadata["doc_id"] = doc_id | |||||
| child_document.metadata["doc_hash"] = hash | |||||
| child_page_content = child_document.page_content | |||||
| if child_page_content.startswith(".") or child_page_content.startswith("。"): | |||||
| child_page_content = child_page_content[1:].strip() | |||||
| if len(child_page_content) > 0: | |||||
| child_document.page_content = child_page_content | |||||
| child_nodes.append(child_document) | |||||
| return child_nodes |
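Editor's note: a hedged sketch of driving the new processor in paragraph parent mode. The field names follow the Rule/ParentMode usage in the code above; the concrete numbers, the separator values, the empty pre_processing_rules list, and passing embedding_model_instance=None are illustrative assumptions, not values taken from this diff.

    from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
    from core.rag.models.document import Document
    from services.entities.knowledge_entities.knowledge_entities import ParentMode

    processor = ParentChildIndexProcessor()
    process_rule = {
        "mode": "hierarchical",
        "rules": {
            "parent_mode": ParentMode.PARAGRAPH.value,
            "pre_processing_rules": [],                 # assumed optional; left empty here
            "segmentation": {"max_tokens": 500, "chunk_overlap": 50, "separator": "\\n\\n"},
            "subchunk_segmentation": {"max_tokens": 200, "chunk_overlap": 20, "separator": "\\n"},
        },
    }
    docs = [Document(page_content="First parent paragraph...\n\nSecond parent paragraph...", metadata={})]
    parents = processor.transform(
        docs,
        process_rule=process_rule,
        embedding_model_instance=None,                  # assumed: splitter falls back to a default tokenizer
    )
    for parent in parents:
        print(parent.page_content[:40], "->", len(parent.children or []), "child chunks")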
| from core.tools.utils.text_processing_utils import remove_leading_symbols | from core.tools.utils.text_processing_utils import remove_leading_symbols | ||||
| from libs import helper | from libs import helper | ||||
| from models.dataset import Dataset | from models.dataset import Dataset | ||||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||||
| class QAIndexProcessor(BaseIndexProcessor): | class QAIndexProcessor(BaseIndexProcessor): | ||||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | ||||
| text_docs = ExtractProcessor.extract( | text_docs = ExtractProcessor.extract( | ||||
| extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" | |||||
| extract_setting=extract_setting, | |||||
| is_automatic=( | |||||
| kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" | |||||
| ), | |||||
| ) | ) | ||||
| return text_docs | return text_docs | ||||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | def transform(self, documents: list[Document], **kwargs) -> list[Document]: | ||||
| preview = kwargs.get("preview") | |||||
| process_rule = kwargs.get("process_rule") | |||||
| rules = Rule(**process_rule.get("rules")) | |||||
| splitter = self._get_splitter( | splitter = self._get_splitter( | ||||
| processing_rule=kwargs.get("process_rule") or {}, | |||||
| processing_rule_mode=process_rule.get("mode"), | |||||
| max_tokens=rules.segmentation.max_tokens, | |||||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||||
| separator=rules.segmentation.separator, | |||||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | embedding_model_instance=kwargs.get("embedding_model_instance"), | ||||
| ) | ) | ||||
| document_node.page_content = remove_leading_symbols(page_content) | document_node.page_content = remove_leading_symbols(page_content) | ||||
| split_documents.append(document_node) | split_documents.append(document_node) | ||||
| all_documents.extend(split_documents) | all_documents.extend(split_documents) | ||||
| for i in range(0, len(all_documents), 10): | |||||
| threads = [] | |||||
| sub_documents = all_documents[i : i + 10] | |||||
| for doc in sub_documents: | |||||
| document_format_thread = threading.Thread( | |||||
| target=self._format_qa_document, | |||||
| kwargs={ | |||||
| "flask_app": current_app._get_current_object(), # type: ignore | |||||
| "tenant_id": kwargs.get("tenant_id"), | |||||
| "document_node": doc, | |||||
| "all_qa_documents": all_qa_documents, | |||||
| "document_language": kwargs.get("doc_language", "English"), | |||||
| }, | |||||
| ) | |||||
| threads.append(document_format_thread) | |||||
| document_format_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| if preview: | |||||
| self._format_qa_document( | |||||
| current_app._get_current_object(), | |||||
| kwargs.get("tenant_id"), | |||||
| all_documents[0], | |||||
| all_qa_documents, | |||||
| kwargs.get("doc_language", "English"), | |||||
| ) | |||||
| else: | |||||
| for i in range(0, len(all_documents), 10): | |||||
| threads = [] | |||||
| sub_documents = all_documents[i : i + 10] | |||||
| for doc in sub_documents: | |||||
| document_format_thread = threading.Thread( | |||||
| target=self._format_qa_document, | |||||
| kwargs={ | |||||
| "flask_app": current_app._get_current_object(), | |||||
| "tenant_id": kwargs.get("tenant_id"), | |||||
| "document_node": doc, | |||||
| "all_qa_documents": all_qa_documents, | |||||
| "document_language": kwargs.get("doc_language", "English"), | |||||
| }, | |||||
| ) | |||||
| threads.append(document_format_thread) | |||||
| document_format_thread.start() | |||||
| for thread in threads: | |||||
| thread.join() | |||||
| return all_qa_documents | return all_qa_documents | ||||
| def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: | def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: | ||||
| raise ValueError(str(e)) | raise ValueError(str(e)) | ||||
| return text_docs | return text_docs | ||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||||
| if dataset.indexing_technique == "high_quality": | if dataset.indexing_technique == "high_quality": | ||||
| vector = Vector(dataset) | vector = Vector(dataset) | ||||
| vector.create(documents) | vector.create(documents) | ||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||||
| vector = Vector(dataset) | vector = Vector(dataset) | ||||
| if node_ids: | if node_ids: | ||||
| vector.delete_by_ids(node_ids) | vector.delete_by_ids(node_ids) |
| from pydantic import BaseModel, Field | from pydantic import BaseModel, Field | ||||
| class ChildDocument(BaseModel): | |||||
| """Class for storing a piece of text and associated metadata.""" | |||||
| page_content: str | |||||
| vector: Optional[list[float]] = None | |||||
| """Arbitrary metadata about the page content (e.g., source, relationships to other | |||||
| documents, etc.). | |||||
| """ | |||||
| metadata: Optional[dict] = Field(default_factory=dict) | |||||
| class Document(BaseModel): | class Document(BaseModel): | ||||
| """Class for storing a piece of text and associated metadata.""" | """Class for storing a piece of text and associated metadata.""" | ||||
| provider: Optional[str] = "dify" | provider: Optional[str] = "dify" | ||||
| children: Optional[list[ChildDocument]] = None | |||||
| class BaseDocumentTransformer(ABC): | class BaseDocumentTransformer(ABC): | ||||
| """Abstract base class for document transformation systems. | """Abstract base class for document transformation systems. |
| "content": item.page_content, | "content": item.page_content, | ||||
| } | } | ||||
| retrieval_resource_list.append(source) | retrieval_resource_list.append(source) | ||||
| document_score_list = {} | |||||
| # deal with dify documents | # deal with dify documents | ||||
| if dify_documents: | if dify_documents: | ||||
| for item in dify_documents: | |||||
| if item.metadata.get("score"): | |||||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||||
| index_node_ids = [document.metadata["doc_id"] for document in dify_documents] | |||||
| segments = DocumentSegment.query.filter( | |||||
| DocumentSegment.dataset_id.in_(dataset_ids), | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||||
| ).all() | |||||
| if segments: | |||||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||||
| sorted_segments = sorted( | |||||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||||
| ) | |||||
| for segment in sorted_segments: | |||||
| records = RetrievalService.format_retrieval_documents(dify_documents) | |||||
| if records: | |||||
| for record in records: | |||||
| segment = record.segment | |||||
| if segment.answer: | if segment.answer: | ||||
| document_context_list.append( | document_context_list.append( | ||||
| DocumentContext( | DocumentContext( | ||||
| content=f"question:{segment.get_sign_content()} answer:{segment.answer}", | content=f"question:{segment.get_sign_content()} answer:{segment.answer}", | ||||
| score=document_score_list.get(segment.index_node_id, None), | |||||
| score=record.score, | |||||
| ) | ) | ||||
| ) | ) | ||||
| else: | else: | ||||
| document_context_list.append( | document_context_list.append( | ||||
| DocumentContext( | DocumentContext( | ||||
| content=segment.get_sign_content(), | content=segment.get_sign_content(), | ||||
| score=document_score_list.get(segment.index_node_id, None), | |||||
| score=record.score, | |||||
| ) | ) | ||||
| ) | ) | ||||
| if show_retrieve_source: | if show_retrieve_source: | ||||
| for segment in sorted_segments: | |||||
| for record in records: | |||||
| segment = record.segment | |||||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | ||||
| document = DatasetDocument.query.filter( | document = DatasetDocument.query.filter( | ||||
| DatasetDocument.id == segment.document_id, | DatasetDocument.id == segment.document_id, | ||||
| "data_source_type": document.data_source_type, | "data_source_type": document.data_source_type, | ||||
| "segment_id": segment.id, | "segment_id": segment.id, | ||||
| "retriever_from": invoke_from.to_source(), | "retriever_from": invoke_from.to_source(), | ||||
| "score": document_score_list.get(segment.index_node_id, 0.0), | |||||
| "score": record.score or 0.0, | |||||
| } | } | ||||
| if invoke_from.to_source() == "dev": | if invoke_from.to_source() == "dev": |
| from core.model_manager import ModelInstance, ModelManager | from core.model_manager import ModelInstance, ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | from core.model_runtime.entities.model_entities import ModelFeature, ModelType | ||||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | ||||
| from core.rag.datasource.retrieval_service import RetrievalService | |||||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from core.variables import StringSegment | from core.variables import StringSegment | ||||
| from core.workflow.nodes.base import BaseNode | from core.workflow.nodes.base import BaseNode | ||||
| from core.workflow.nodes.enums import NodeType | from core.workflow.nodes.enums import NodeType | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, Document, DocumentSegment | |||||
| from models.dataset import Dataset, Document | |||||
| from models.workflow import WorkflowNodeExecutionStatus | from models.workflow import WorkflowNodeExecutionStatus | ||||
| from .entities import KnowledgeRetrievalNodeData | from .entities import KnowledgeRetrievalNodeData | ||||
| "content": item.page_content, | "content": item.page_content, | ||||
| } | } | ||||
| retrieval_resource_list.append(source) | retrieval_resource_list.append(source) | ||||
| document_score_list: dict[str, float] = {} | |||||
| # deal with dify documents | # deal with dify documents | ||||
| if dify_documents: | if dify_documents: | ||||
| document_score_list = {} | |||||
| for item in dify_documents: | |||||
| if item.metadata.get("score"): | |||||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||||
| index_node_ids = [document.metadata["doc_id"] for document in dify_documents] | |||||
| segments = DocumentSegment.query.filter( | |||||
| DocumentSegment.dataset_id.in_(dataset_ids), | |||||
| DocumentSegment.completed_at.isnot(None), | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||||
| ).all() | |||||
| if segments: | |||||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||||
| sorted_segments = sorted( | |||||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||||
| ) | |||||
| for segment in sorted_segments: | |||||
| records = RetrievalService.format_retrieval_documents(dify_documents) | |||||
| if records: | |||||
| for record in records: | |||||
| segment = record.segment | |||||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | ||||
| document = Document.query.filter( | document = Document.query.filter( | ||||
| Document.id == segment.document_id, | Document.id == segment.document_id, | ||||
| "document_data_source_type": document.data_source_type, | "document_data_source_type": document.data_source_type, | ||||
| "segment_id": segment.id, | "segment_id": segment.id, | ||||
| "retriever_from": "workflow", | "retriever_from": "workflow", | ||||
| "score": document_score_list.get(segment.index_node_id, None), | |||||
| "score": record.score or 0.0, | |||||
| "segment_hit_count": segment.hit_count, | "segment_hit_count": segment.hit_count, | ||||
| "segment_word_count": segment.word_count, | "segment_word_count": segment.word_count, | ||||
| "segment_position": segment.position, | "segment_position": segment.position, | ||||
| key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, | key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, | ||||
| reverse=True, | reverse=True, | ||||
| ) | ) | ||||
| position = 1 | |||||
| for item in retrieval_resource_list: | |||||
| for position, item in enumerate(retrieval_resource_list, start=1): | |||||
| item["metadata"]["position"] = position | item["metadata"]["position"] = position | ||||
| position += 1 | |||||
| return retrieval_resource_list | return retrieval_resource_list | ||||
| @classmethod | @classmethod |
| "embedding_available": fields.Boolean, | "embedding_available": fields.Boolean, | ||||
| "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), | "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), | ||||
| "tags": fields.List(fields.Nested(tag_fields)), | "tags": fields.List(fields.Nested(tag_fields)), | ||||
| "doc_form": fields.String, | |||||
| "external_knowledge_info": fields.Nested(external_knowledge_info_fields), | "external_knowledge_info": fields.Nested(external_knowledge_info_fields), | ||||
| "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), | "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), | ||||
| } | } |
| "data_source_info": fields.Raw(attribute="data_source_info_dict"), | "data_source_info": fields.Raw(attribute="data_source_info_dict"), | ||||
| "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), | "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), | ||||
| "dataset_process_rule_id": fields.String, | "dataset_process_rule_id": fields.String, | ||||
| "process_rule_dict": fields.Raw(attribute="process_rule_dict"), | |||||
| "name": fields.String, | "name": fields.String, | ||||
| "created_from": fields.String, | "created_from": fields.String, | ||||
| "created_by": fields.String, | "created_by": fields.String, |
| "document": fields.Nested(document_fields), | "document": fields.Nested(document_fields), | ||||
| } | } | ||||
| child_chunk_fields = { | |||||
| "id": fields.String, | |||||
| "content": fields.String, | |||||
| "position": fields.Integer, | |||||
| "score": fields.Float, | |||||
| } | |||||
| hit_testing_record_fields = { | hit_testing_record_fields = { | ||||
| "segment": fields.Nested(segment_fields), | "segment": fields.Nested(segment_fields), | ||||
| "child_chunks": fields.List(fields.Nested(child_chunk_fields)), | |||||
| "score": fields.Float, | "score": fields.Float, | ||||
| "tsne_position": fields.Raw, | "tsne_position": fields.Raw, | ||||
| } | } |
| from libs.helper import TimestampField | from libs.helper import TimestampField | ||||
| child_chunk_fields = { | |||||
| "id": fields.String, | |||||
| "segment_id": fields.String, | |||||
| "content": fields.String, | |||||
| "position": fields.Integer, | |||||
| "word_count": fields.Integer, | |||||
| "type": fields.String, | |||||
| "created_at": TimestampField, | |||||
| "updated_at": TimestampField, | |||||
| } | |||||
| segment_fields = { | segment_fields = { | ||||
| "id": fields.String, | "id": fields.String, | ||||
| "position": fields.Integer, | "position": fields.Integer, | ||||
| "status": fields.String, | "status": fields.String, | ||||
| "created_by": fields.String, | "created_by": fields.String, | ||||
| "created_at": TimestampField, | "created_at": TimestampField, | ||||
| "updated_at": TimestampField, | |||||
| "updated_by": fields.String, | |||||
| "indexing_at": TimestampField, | "indexing_at": TimestampField, | ||||
| "completed_at": TimestampField, | "completed_at": TimestampField, | ||||
| "error": fields.String, | "error": fields.String, | ||||
| "stopped_at": TimestampField, | "stopped_at": TimestampField, | ||||
| "child_chunks": fields.List(fields.Nested(child_chunk_fields)), | |||||
| } | } | ||||
| segment_list_response = { | segment_list_response = { |
| """parent-child-index | |||||
| Revision ID: e19037032219 | |||||
| Revises: d7999dfa4aae | |||||
| Create Date: 2024-11-22 07:01:17.550037 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = 'e19037032219' | |||||
| down_revision = 'd7999dfa4aae' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table('child_chunks', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('document_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('segment_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('position', sa.Integer(), nullable=False), | |||||
| sa.Column('content', sa.Text(), nullable=False), | |||||
| sa.Column('word_count', sa.Integer(), nullable=False), | |||||
| sa.Column('index_node_id', sa.String(length=255), nullable=True), | |||||
| sa.Column('index_node_hash', sa.String(length=255), nullable=True), | |||||
| sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), | |||||
| sa.Column('created_by', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.Column('updated_by', models.types.StringUUID(), nullable=True), | |||||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.Column('indexing_at', sa.DateTime(), nullable=True), | |||||
| sa.Column('completed_at', sa.DateTime(), nullable=True), | |||||
| sa.Column('error', sa.Text(), nullable=True), | |||||
| sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('child_chunks', schema=None) as batch_op: | |||||
| batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('child_chunks', schema=None) as batch_op: | |||||
| batch_op.drop_index('child_chunk_dataset_id_idx') | |||||
| op.drop_table('child_chunks') | |||||
| # ### end Alembic commands ### |
| """add_auto_disabled_dataset_logs | |||||
| Revision ID: 923752d42eb6 | |||||
| Revises: e19037032219 | |||||
| Create Date: 2024-12-25 11:37:55.467101 | |||||
| """ | |||||
| from alembic import op | |||||
| import models as models | |||||
| import sqlalchemy as sa | |||||
| from sqlalchemy.dialects import postgresql | |||||
| # revision identifiers, used by Alembic. | |||||
| revision = '923752d42eb6' | |||||
| down_revision = 'e19037032219' | |||||
| branch_labels = None | |||||
| depends_on = None | |||||
| def upgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| op.create_table('dataset_auto_disable_logs', | |||||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('document_id', models.types.StringUUID(), nullable=False), | |||||
| sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), | |||||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||||
| sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') | |||||
| ) | |||||
| with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: | |||||
| batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) | |||||
| batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) | |||||
| batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) | |||||
| # ### end Alembic commands ### | |||||
| def downgrade(): | |||||
| # ### commands auto generated by Alembic - please adjust! ### | |||||
| with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: | |||||
| batch_op.drop_index('dataset_auto_disable_log_tenant_idx') | |||||
| batch_op.drop_index('dataset_auto_disable_log_dataset_idx') | |||||
| batch_op.drop_index('dataset_auto_disable_log_created_atx') | |||||
| op.drop_table('dataset_auto_disable_logs') | |||||
| # ### end Alembic commands ### |
| from configs import dify_config | from configs import dify_config | ||||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from extensions.ext_storage import storage | from extensions.ext_storage import storage | ||||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||||
| from .account import Account | from .account import Account | ||||
| from .engine import db | from .engine import db | ||||
| created_by = db.Column(StringUUID, nullable=False) | created_by = db.Column(StringUUID, nullable=False) | ||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| MODES = ["automatic", "custom"] | |||||
| MODES = ["automatic", "custom", "hierarchical"] | |||||
| PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] | PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] | ||||
| AUTOMATIC_RULES: dict[str, Any] = { | AUTOMATIC_RULES: dict[str, Any] = { | ||||
| "pre_processing_rules": [ | "pre_processing_rules": [ | ||||
| "dataset_id": self.dataset_id, | "dataset_id": self.dataset_id, | ||||
| "mode": self.mode, | "mode": self.mode, | ||||
| "rules": self.rules_dict, | "rules": self.rules_dict, | ||||
| "created_by": self.created_by, | |||||
| "created_at": self.created_at, | |||||
| } | } | ||||
| @property | @property | ||||
| .scalar() | .scalar() | ||||
| ) | ) | ||||
| @property | |||||
| def process_rule_dict(self): | |||||
| if self.dataset_process_rule_id: | |||||
| return self.dataset_process_rule.to_dict() | |||||
| return None | |||||
| def to_dict(self): | def to_dict(self): | ||||
| return { | return { | ||||
| "id": self.id, | "id": self.id, | ||||
| .first() | .first() | ||||
| ) | ) | ||||
| @property | |||||
| def child_chunks(self): | |||||
| process_rule = self.document.dataset_process_rule | |||||
| if process_rule.mode == "hierarchical": | |||||
| rules = Rule(**process_rule.rules_dict) | |||||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||||
| child_chunks = ( | |||||
| db.session.query(ChildChunk) | |||||
| .filter(ChildChunk.segment_id == self.id) | |||||
| .order_by(ChildChunk.position.asc()) | |||||
| .all() | |||||
| ) | |||||
| return child_chunks or [] | |||||
| else: | |||||
| return [] | |||||
| else: | |||||
| return [] | |||||
| def get_sign_content(self): | def get_sign_content(self): | ||||
| signed_urls = [] | signed_urls = [] | ||||
| text = self.content | text = self.content | ||||
| return text | return text | ||||
| class ChildChunk(db.Model): | |||||
| __tablename__ = "child_chunks" | |||||
| __table_args__ = ( | |||||
| db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), | |||||
| db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), | |||||
| ) | |||||
| # initial fields | |||||
| id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| dataset_id = db.Column(StringUUID, nullable=False) | |||||
| document_id = db.Column(StringUUID, nullable=False) | |||||
| segment_id = db.Column(StringUUID, nullable=False) | |||||
| position = db.Column(db.Integer, nullable=False) | |||||
| content = db.Column(db.Text, nullable=False) | |||||
| word_count = db.Column(db.Integer, nullable=False) | |||||
| # indexing fields | |||||
| index_node_id = db.Column(db.String(255), nullable=True) | |||||
| index_node_hash = db.Column(db.String(255), nullable=True) | |||||
| type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) | |||||
| created_by = db.Column(StringUUID, nullable=False) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||||
| updated_by = db.Column(StringUUID, nullable=True) | |||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||||
| indexing_at = db.Column(db.DateTime, nullable=True) | |||||
| completed_at = db.Column(db.DateTime, nullable=True) | |||||
| error = db.Column(db.Text, nullable=True) | |||||
| @property | |||||
| def dataset(self): | |||||
| return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() | |||||
| @property | |||||
| def document(self): | |||||
| return db.session.query(Document).filter(Document.id == self.document_id).first() | |||||
| @property | |||||
| def segment(self): | |||||
| return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() | |||||
| class AppDatasetJoin(db.Model): # type: ignore[name-defined] | class AppDatasetJoin(db.Model): # type: ignore[name-defined] | ||||
| __tablename__ = "app_dataset_joins" | __tablename__ = "app_dataset_joins" | ||||
| __table_args__ = ( | __table_args__ = ( | ||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| updated_by = db.Column(StringUUID, nullable=True) | updated_by = db.Column(StringUUID, nullable=True) | ||||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | ||||
| class DatasetAutoDisableLog(db.Model): | |||||
| __tablename__ = "dataset_auto_disable_logs" | |||||
| __table_args__ = ( | |||||
| db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), | |||||
| db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), | |||||
| db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), | |||||
| db.Index("dataset_auto_disable_log_created_atx", "created_at"), | |||||
| ) | |||||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||||
| tenant_id = db.Column(StringUUID, nullable=False) | |||||
| dataset_id = db.Column(StringUUID, nullable=False) | |||||
| document_id = db.Column(StringUUID, nullable=False) | |||||
| notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) |
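As a usage sketch (not part of the diff), the intent of DatasetAutoDisableLog is to record each document that a cleanup pass disables so a later job can send one digest per workspace; under that assumption, the un-notified rows would be collected and grouped per tenant roughly like this:

    from collections import defaultdict

    from extensions.ext_database import db
    from models.dataset import DatasetAutoDisableLog

    # Gather rows that have not been notified yet and bucket them per tenant,
    # so a single digest mail can be sent to each workspace owner.
    pending_logs = (
        db.session.query(DatasetAutoDisableLog)
        .filter(DatasetAutoDisableLog.notified == False)
        .all()
    )
    logs_by_tenant: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
    for log in pending_logs:
        logs_by_tenant[log.tenant_id].append(log)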
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import Dataset, DatasetQuery, Document | |||||
| from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document | |||||
| from services.feature_service import FeatureService | from services.feature_service import FeatureService | ||||
| ) | ) | ||||
| if not dataset_query or len(dataset_query) == 0: | if not dataset_query or len(dataset_query) == 0: | ||||
| try: | try: | ||||
| # add auto disable log | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .filter( | |||||
| Document.dataset_id == dataset.id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| for document in documents: | |||||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||||
| tenant_id=dataset.tenant_id, | |||||
| dataset_id=dataset.id, | |||||
| document_id=document.id, | |||||
| ) | |||||
| db.session.add(dataset_auto_disable_log) | |||||
| # remove index | # remove index | ||||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | ||||
| index_processor.clean(dataset, None) | index_processor.clean(dataset, None) | ||||
| else: | else: | ||||
| plan = plan_cache.decode() | plan = plan_cache.decode() | ||||
| if plan == "sandbox": | if plan == "sandbox": | ||||
| # add auto disable log | |||||
| documents = ( | |||||
| db.session.query(Document) | |||||
| .filter( | |||||
| Document.dataset_id == dataset.id, | |||||
| Document.enabled == True, | |||||
| Document.archived == False, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| for document in documents: | |||||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||||
| tenant_id=dataset.tenant_id, | |||||
| dataset_id=dataset.id, | |||||
| document_id=document.id, | |||||
| ) | |||||
| db.session.add(dataset_auto_disable_log) | |||||
| # remove index | # remove index | ||||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | ||||
| index_processor.clean(dataset, None) | index_processor.clean(dataset, None) |
| import logging | |||||
| import time | |||||
| from collections import defaultdict | |||||
| import click | |||||
| from celery import shared_task | |||||
| from flask import render_template | |||||
| from extensions.ext_mail import mail | |||||
| from models.account import Account, Tenant, TenantAccountJoin | |||||
| from models.dataset import Dataset, DatasetAutoDisableLog | |||||
| @shared_task(queue="mail") | |||||
| def send_document_clean_notify_task(): | |||||
| """ | |||||
| Async Send document clean notify mail | |||||
| Usage: send_document_clean_notify_task.delay() | |||||
| """ | |||||
| if not mail.is_inited(): | |||||
| return | |||||
| logging.info(click.style("Start send document clean notify mail", fg="green")) | |||||
| start_at = time.perf_counter() | |||||
| # send document clean notify mail | |||||
| try: | |||||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() | |||||
| # group by tenant_id | |||||
| dataset_auto_disable_logs_map = defaultdict(list) | |||||
| for dataset_auto_disable_log in dataset_auto_disable_logs: | |||||
| dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log) | |||||
| for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): | |||||
| knowledge_details = [] | |||||
| tenant = Tenant.query.filter(Tenant.id == tenant_id).first() | |||||
| if not tenant: | |||||
| continue | |||||
| current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() | |||||
| if not current_owner_join: | |||||
| continue | |||||
| account = Account.query.filter(Account.id == current_owner_join.account_id).first() | |||||
| if not account: | |||||
| continue | |||||
| dataset_auto_dataset_map = defaultdict(list) | |||||
| for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: | |||||
| dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append( | |||||
| dataset_auto_disable_log.document_id | |||||
| ) | |||||
| for dataset_id, document_ids in dataset_auto_dataset_map.items(): | |||||
| dataset = Dataset.query.filter(Dataset.id == dataset_id).first() | |||||
| if dataset: | |||||
| document_count = len(document_ids) | |||||
| knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>") | |||||
| html_content = render_template( | |||||
| "clean_document_job_mail_template-US.html", | |||||
| knowledge_details=knowledge_details, | |||||
| ) | |||||
| mail.send( | |||||
| to=account.email, | |||||
| subject="Some documents in your Dify knowledge base have been disabled", | |||||
| html=html_content, | |||||
| ) | |||||
| end_at = time.perf_counter() | |||||
| logging.info( | |||||
| click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") | |||||
| ) | |||||
| except Exception: | |||||
| logging.exception("Send invite member mail to {} failed".format(to)) |
| from typing import Optional | |||||
| from enum import Enum | |||||
| from typing import Literal, Optional | |||||
| from pydantic import BaseModel | from pydantic import BaseModel | ||||
| answer: Optional[str] = None | answer: Optional[str] = None | ||||
| keywords: Optional[list[str]] = None | keywords: Optional[list[str]] = None | ||||
| enabled: Optional[bool] = None | enabled: Optional[bool] = None | ||||
| class ParentMode(str, Enum): | |||||
| FULL_DOC = "full-doc" | |||||
| PARAGRAPH = "paragraph" | |||||
| class NotionIcon(BaseModel): | |||||
| type: str | |||||
| url: Optional[str] = None | |||||
| emoji: Optional[str] = None | |||||
| class NotionPage(BaseModel): | |||||
| page_id: str | |||||
| page_name: str | |||||
| page_icon: Optional[NotionIcon] = None | |||||
| type: str | |||||
| class NotionInfo(BaseModel): | |||||
| workspace_id: str | |||||
| pages: list[NotionPage] | |||||
| class WebsiteInfo(BaseModel): | |||||
| provider: str | |||||
| job_id: str | |||||
| urls: list[str] | |||||
| only_main_content: bool = True | |||||
| class FileInfo(BaseModel): | |||||
| file_ids: list[str] | |||||
| class InfoList(BaseModel): | |||||
| data_source_type: Literal["upload_file", "notion_import", "website_crawl"] | |||||
| notion_info_list: Optional[list[NotionInfo]] = None | |||||
| file_info_list: Optional[FileInfo] = None | |||||
| website_info_list: Optional[WebsiteInfo] = None | |||||
| class DataSource(BaseModel): | |||||
| info_list: InfoList | |||||
| class PreProcessingRule(BaseModel): | |||||
| id: str | |||||
| enabled: bool | |||||
| class Segmentation(BaseModel): | |||||
| separator: str = "\n" | |||||
| max_tokens: int | |||||
| chunk_overlap: int = 0 | |||||
| class Rule(BaseModel): | |||||
| pre_processing_rules: Optional[list[PreProcessingRule]] = None | |||||
| segmentation: Optional[Segmentation] = None | |||||
| parent_mode: Optional[Literal["full-doc", "paragraph"]] = None | |||||
| subchunk_segmentation: Optional[Segmentation] = None | |||||
| class ProcessRule(BaseModel): | |||||
| mode: Literal["automatic", "custom", "hierarchical"] | |||||
| rules: Optional[Rule] = None | |||||
| class RerankingModel(BaseModel): | |||||
| reranking_provider_name: Optional[str] = None | |||||
| reranking_model_name: Optional[str] = None | |||||
| class RetrievalModel(BaseModel): | |||||
| search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] | |||||
| reranking_enable: bool | |||||
| reranking_model: Optional[RerankingModel] = None | |||||
| top_k: int | |||||
| score_threshold_enabled: bool | |||||
| score_threshold: Optional[float] = None | |||||
| class KnowledgeConfig(BaseModel): | |||||
| original_document_id: Optional[str] = None | |||||
| duplicate: bool = True | |||||
| indexing_technique: Literal["high_quality", "economy"] | |||||
| data_source: Optional[DataSource] = None | |||||
| process_rule: Optional[ProcessRule] = None | |||||
| retrieval_model: Optional[RetrievalModel] = None | |||||
| doc_form: str = "text_model" | |||||
| doc_language: str = "English" | |||||
| embedding_model: Optional[str] = None | |||||
| embedding_model_provider: Optional[str] = None | |||||
| name: Optional[str] = None | |||||
| class SegmentUpdateArgs(BaseModel): | |||||
| content: Optional[str] = None | |||||
| answer: Optional[str] = None | |||||
| keywords: Optional[list[str]] = None | |||||
| regenerate_child_chunks: bool = False | |||||
| enabled: Optional[bool] = None | |||||
| class ChildChunkUpdateArgs(BaseModel): | |||||
| id: Optional[str] = None | |||||
| content: str |
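As a usage sketch of the new entities (all values below are illustrative, not taken from the diff), a parent-child setup combines mode="hierarchical", a parent_mode, and a subchunk_segmentation:

    from services.entities.knowledge_entities.knowledge_entities import (
        KnowledgeConfig,
        PreProcessingRule,
        ProcessRule,
        Rule,
        Segmentation,
    )

    # Parent chunks are cut with `segmentation`, child chunks with `subchunk_segmentation`;
    # `parent_mode="paragraph"` keeps per-paragraph parents instead of one full-doc parent.
    config = KnowledgeConfig(
        indexing_technique="high_quality",
        doc_form="hierarchical_model",  # assumed doc_form value for the parent-child index type
        process_rule=ProcessRule(
            mode="hierarchical",
            rules=Rule(
                pre_processing_rules=[PreProcessingRule(id="remove_extra_spaces", enabled=True)],
                segmentation=Segmentation(separator="\n\n", max_tokens=1024),
                parent_mode="paragraph",
                subchunk_segmentation=Segmentation(separator="\n", max_tokens=256),
            ),
        ),
    )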
| from services.errors.base import BaseServiceError | |||||
| class ChildChunkIndexingError(BaseServiceError): | |||||
| description = "{message}" | |||||
| class ChildChunkDeleteIndexError(BaseServiceError): | |||||
| description = "{message}" |
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | from core.rag.retrieval.retrieval_methods import RetrievalMethod | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.account import Account | from models.account import Account | ||||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||||
| from models.dataset import Dataset, DatasetQuery | |||||
| default_retrieval_model = { | default_retrieval_model = { | ||||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | ||||
| db.session.add(dataset_query) | db.session.add(dataset_query) | ||||
| db.session.commit() | db.session.commit() | ||||
| return dict(cls.compact_retrieve_response(dataset, query, all_documents)) | |||||
| return cls.compact_retrieve_response(query, all_documents) | |||||
| @classmethod | @classmethod | ||||
| def external_retrieve( | def external_retrieve( | ||||
| return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) | return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) | ||||
| @classmethod | @classmethod | ||||
| def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): | |||||
| records = [] | |||||
| for document in documents: | |||||
| if document.metadata is None: | |||||
| continue | |||||
| index_node_id = document.metadata["doc_id"] | |||||
| segment = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .filter( | |||||
| DocumentSegment.dataset_id == dataset.id, | |||||
| DocumentSegment.enabled == True, | |||||
| DocumentSegment.status == "completed", | |||||
| DocumentSegment.index_node_id == index_node_id, | |||||
| ) | |||||
| .first() | |||||
| ) | |||||
| if not segment: | |||||
| continue | |||||
| record = { | |||||
| "segment": segment, | |||||
| "score": document.metadata.get("score", None), | |||||
| } | |||||
| records.append(record) | |||||
| def compact_retrieve_response(cls, query: str, documents: list[Document]): | |||||
| records = RetrievalService.format_retrieval_documents(documents) | |||||
| return { | return { | ||||
| "query": { | "query": { | ||||
| "content": query, | "content": query, | ||||
| }, | }, | ||||
| "records": records, | |||||
| "records": [record.model_dump() for record in records], | |||||
| } | } | ||||
| @classmethod | @classmethod |
| from typing import Optional | from typing import Optional | ||||
| from core.model_manager import ModelInstance, ModelManager | |||||
| from core.model_runtime.entities.model_entities import ModelType | |||||
| from core.rag.datasource.keyword.keyword_factory import Keyword | from core.rag.datasource.keyword.keyword_factory import Keyword | ||||
| from core.rag.datasource.vdb.vector_factory import Vector | from core.rag.datasource.vdb.vector_factory import Vector | ||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||||
| from core.rag.models.document import Document | from core.rag.models.document import Document | ||||
| from models.dataset import Dataset, DocumentSegment | |||||
| from extensions.ext_database import db | |||||
| from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | |||||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode | |||||
| class VectorService: | class VectorService: | ||||
| @classmethod | @classmethod | ||||
| def create_segments_vector( | def create_segments_vector( | ||||
| cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset | |||||
| cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str | |||||
| ): | ): | ||||
| documents = [] | documents = [] | ||||
| for segment in segments: | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| documents.append(document) | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| # save vector index | |||||
| vector = Vector(dataset=dataset) | |||||
| vector.add_texts(documents, duplicate_check=True) | |||||
| # save keyword index | |||||
| keyword = Keyword(dataset) | |||||
| for segment in segments: | |||||
| if doc_form == IndexType.PARENT_CHILD_INDEX: | |||||
| document = DatasetDocument.query.filter_by(id=segment.document_id).first() | |||||
| # get the process rule | |||||
| processing_rule = ( | |||||
| db.session.query(DatasetProcessRule) | |||||
| .filter(DatasetProcessRule.id == document.dataset_process_rule_id) | |||||
| .first() | |||||
| ) | |||||
| # get embedding model instance | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| # check embedding model setting | |||||
| model_manager = ModelManager() | |||||
| if keywords_list and len(keywords_list) > 0: | |||||
| keyword.add_texts(documents, keywords_list=keywords_list) | |||||
| else: | |||||
| keyword.add_texts(documents) | |||||
| if dataset.embedding_model_provider: | |||||
| embedding_model_instance = model_manager.get_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| provider=dataset.embedding_model_provider, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| model=dataset.embedding_model, | |||||
| ) | |||||
| else: | |||||
| embedding_model_instance = model_manager.get_default_model_instance( | |||||
| tenant_id=dataset.tenant_id, | |||||
| model_type=ModelType.TEXT_EMBEDDING, | |||||
| ) | |||||
| else: | |||||
| raise ValueError("The knowledge base index technique is not high quality!") | |||||
| cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) | |||||
| else: | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| documents.append(document) | |||||
| if len(documents) > 0: | |||||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | |||||
| index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) | |||||
| @classmethod | @classmethod | ||||
| def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | ||||
| keyword.add_texts([document], keywords_list=[keywords]) | keyword.add_texts([document], keywords_list=[keywords]) | ||||
| else: | else: | ||||
| keyword.add_texts([document]) | keyword.add_texts([document]) | ||||
| @classmethod | |||||
| def generate_child_chunks( | |||||
| cls, | |||||
| segment: DocumentSegment, | |||||
| dataset_document: Document, | |||||
| dataset: Dataset, | |||||
| embedding_model_instance: ModelInstance, | |||||
| processing_rule: DatasetProcessRule, | |||||
| regenerate: bool = False, | |||||
| ): | |||||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||||
| if regenerate: | |||||
| # delete child chunks | |||||
| index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) | |||||
| # generate child chunks | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| # use full doc mode to generate segment's child chunk | |||||
| processing_rule_dict = processing_rule.to_dict() | |||||
| processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value | |||||
| documents = index_processor.transform( | |||||
| [document], | |||||
| embedding_model_instance=embedding_model_instance, | |||||
| process_rule=processing_rule_dict, | |||||
| tenant_id=dataset.tenant_id, | |||||
| doc_language=dataset_document.doc_language, | |||||
| ) | |||||
| # save child chunks | |||||
| if len(documents) > 0 and len(documents[0].children) > 0: | |||||
| index_processor.load(dataset, documents) | |||||
| for position, child_chunk in enumerate(documents[0].children, start=1): | |||||
| child_segment = ChildChunk( | |||||
| tenant_id=dataset.tenant_id, | |||||
| dataset_id=dataset.id, | |||||
| document_id=dataset_document.id, | |||||
| segment_id=segment.id, | |||||
| position=position, | |||||
| index_node_id=child_chunk.metadata["doc_id"], | |||||
| index_node_hash=child_chunk.metadata["doc_hash"], | |||||
| content=child_chunk.page_content, | |||||
| word_count=len(child_chunk.page_content), | |||||
| type="automatic", | |||||
| created_by=dataset_document.created_by, | |||||
| ) | |||||
| db.session.add(child_segment) | |||||
| db.session.commit() | |||||
| @classmethod | |||||
| def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): | |||||
| child_document = Document( | |||||
| page_content=child_segment.content, | |||||
| metadata={ | |||||
| "doc_id": child_segment.index_node_id, | |||||
| "doc_hash": child_segment.index_node_hash, | |||||
| "document_id": child_segment.document_id, | |||||
| "dataset_id": child_segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| # save vector index | |||||
| vector = Vector(dataset=dataset) | |||||
| vector.add_texts([child_document], duplicate_check=True) | |||||
| @classmethod | |||||
| def update_child_chunk_vector( | |||||
| cls, | |||||
| new_child_chunks: list[ChildChunk], | |||||
| update_child_chunks: list[ChildChunk], | |||||
| delete_child_chunks: list[ChildChunk], | |||||
| dataset: Dataset, | |||||
| ): | |||||
| documents = [] | |||||
| delete_node_ids = [] | |||||
| for new_child_chunk in new_child_chunks: | |||||
| new_child_document = Document( | |||||
| page_content=new_child_chunk.content, | |||||
| metadata={ | |||||
| "doc_id": new_child_chunk.index_node_id, | |||||
| "doc_hash": new_child_chunk.index_node_hash, | |||||
| "document_id": new_child_chunk.document_id, | |||||
| "dataset_id": new_child_chunk.dataset_id, | |||||
| }, | |||||
| ) | |||||
| documents.append(new_child_document) | |||||
| for update_child_chunk in update_child_chunks: | |||||
| child_document = Document( | |||||
| page_content=update_child_chunk.content, | |||||
| metadata={ | |||||
| "doc_id": update_child_chunk.index_node_id, | |||||
| "doc_hash": update_child_chunk.index_node_hash, | |||||
| "document_id": update_child_chunk.document_id, | |||||
| "dataset_id": update_child_chunk.dataset_id, | |||||
| }, | |||||
| ) | |||||
| documents.append(child_document) | |||||
| delete_node_ids.append(update_child_chunk.index_node_id) | |||||
| for delete_child_chunk in delete_child_chunks: | |||||
| delete_node_ids.append(delete_child_chunk.index_node_id) | |||||
| if dataset.indexing_technique == "high_quality": | |||||
| # update vector index | |||||
| vector = Vector(dataset=dataset) | |||||
| if delete_node_ids: | |||||
| vector.delete_by_ids(delete_node_ids) | |||||
| if documents: | |||||
| vector.add_texts(documents, duplicate_check=True) | |||||
| @classmethod | |||||
| def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): | |||||
| vector = Vector(dataset=dataset) | |||||
| vector.delete_by_ids([child_chunk.index_node_id]) |
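From a caller's perspective, the visible change to create_segments_vector is the extra doc_form argument, which decides whether child chunks are generated; the batch task further down passes it straight from the owning document, roughly like this:

    # `segments` are newly created DocumentSegment rows, `dataset_document` their parent Document.
    # Keywords are left to the index processor here (keywords_list=None).
    VectorService.create_segments_vector(None, segments, dataset, dataset_document.doc_form)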
| from celery import shared_task # type: ignore | from celery import shared_task # type: ignore | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from core.rag.models.document import Document | |||||
| from core.rag.models.document import ChildDocument, Document | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DatasetAutoDisableLog, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | from models.dataset import Document as DatasetDocument | ||||
| from models.dataset import DocumentSegment | |||||
| @shared_task(queue="dataset") | @shared_task(queue="dataset") | ||||
| "dataset_id": segment.dataset_id, | "dataset_id": segment.dataset_id, | ||||
| }, | }, | ||||
| ) | ) | ||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||||
| child_chunks = segment.child_chunks | |||||
| if child_chunks: | |||||
| child_documents = [] | |||||
| for child_chunk in child_chunks: | |||||
| child_document = ChildDocument( | |||||
| page_content=child_chunk.content, | |||||
| metadata={ | |||||
| "doc_id": child_chunk.index_node_id, | |||||
| "doc_hash": child_chunk.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| child_documents.append(child_document) | |||||
| document.children = child_documents | |||||
| documents.append(document) | documents.append(document) | ||||
| dataset = dataset_document.dataset | dataset = dataset_document.dataset | ||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| index_processor.load(dataset, documents) | index_processor.load(dataset, documents) | ||||
| # delete auto disable log | |||||
| db.session.query(DatasetAutoDisableLog).filter( | |||||
| DatasetAutoDisableLog.document_id == dataset_document.id | |||||
| ).delete() | |||||
| db.session.commit() | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | logging.info( | ||||
| click.style( | click.style( |
| import logging | |||||
| import time | |||||
| import click | |||||
| from celery import shared_task | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||||
| from extensions.ext_database import db | |||||
| from extensions.ext_storage import storage | |||||
| from models.dataset import Dataset, DocumentSegment | |||||
| from models.model import UploadFile | |||||
| @shared_task(queue="dataset") | |||||
| def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): | |||||
| """ | |||||
| Clean documents when documents are deleted. | |||||
| :param document_ids: document ids | |||||
| :param dataset_id: dataset id | |||||
| :param doc_form: doc_form | |||||
| :param file_ids: file ids | |||||
| Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids) | |||||
| """ | |||||
| logging.info(click.style("Start batch clean documents when documents deleted", fg="green")) | |||||
| start_at = time.perf_counter() | |||||
| try: | |||||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||||
| if not dataset: | |||||
| raise Exception("Document has no dataset") | |||||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() | |||||
| # check segment is exist | |||||
| if segments: | |||||
| index_node_ids = [segment.index_node_id for segment in segments] | |||||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | |||||
| image_upload_file_ids = get_image_upload_file_ids(segment.content) | |||||
| for upload_file_id in image_upload_file_ids: | |||||
| image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() | |||||
| if image_file is None: | |||||
| continue | |||||
| try: | |||||
| storage.delete(image_file.key) | |||||
| except Exception: | |||||
| logging.exception( | |||||
| "Delete image_files failed when storage deleted, \ | |||||
| image_upload_file_id: {}".format(upload_file_id) | |||||
| ) | |||||
| db.session.delete(image_file) | |||||
| db.session.delete(segment) | |||||
| db.session.commit() | |||||
| if file_ids: | |||||
| files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() | |||||
| for file in files: | |||||
| try: | |||||
| storage.delete(file.key) | |||||
| except Exception: | |||||
| logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) | |||||
| db.session.delete(file) | |||||
| db.session.commit() | |||||
| end_at = time.perf_counter() | |||||
| logging.info( | |||||
| click.style( | |||||
| "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), | |||||
| fg="green", | |||||
| ) | |||||
| ) | |||||
| except Exception: | |||||
| logging.exception("Cleaned documents when documents deleted failed") |
| from celery import shared_task # type: ignore | from celery import shared_task # type: ignore | ||||
| from sqlalchemy import func | from sqlalchemy import func | ||||
| from core.indexing_runner import IndexingRunner | |||||
| from core.model_manager import ModelManager | from core.model_manager import ModelManager | ||||
| from core.model_runtime.entities.model_entities import ModelType | from core.model_runtime.entities.model_entities import ModelType | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from libs import helper | from libs import helper | ||||
| from models.dataset import Dataset, Document, DocumentSegment | from models.dataset import Dataset, Document, DocumentSegment | ||||
| from services.vector_service import VectorService | |||||
| @shared_task(queue="dataset") | @shared_task(queue="dataset") | ||||
| dataset_document.word_count += word_count_change | dataset_document.word_count += word_count_change | ||||
| db.session.add(dataset_document) | db.session.add(dataset_document) | ||||
| # add index to db | # add index to db | ||||
| indexing_runner = IndexingRunner() | |||||
| indexing_runner.batch_add_segments(document_segments, dataset) | |||||
| VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) | |||||
| db.session.commit() | db.session.commit() | ||||
| redis_client.setex(indexing_cache_key, 600, "completed") | redis_client.setex(indexing_cache_key, 600, "completed") | ||||
| end_at = time.perf_counter() | end_at = time.perf_counter() |
| if doc_form is None: | if doc_form is None: | ||||
| raise ValueError("Index type must be specified.") | raise ValueError("Index type must be specified.") | ||||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | index_processor = IndexProcessorFactory(doc_form).init_index_processor() | ||||
| index_processor.clean(dataset, None) | |||||
| index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) | |||||
| for document in documents: | for document in documents: | ||||
| db.session.delete(document) | db.session.delete(document) |
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | index_processor = IndexProcessorFactory(doc_form).init_index_processor() | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| image_upload_file_ids = get_image_upload_file_ids(segment.content) | image_upload_file_ids = get_image_upload_file_ids(segment.content) |
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) |
| import click | import click | ||||
| from celery import shared_task # type: ignore | from celery import shared_task # type: ignore | ||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from core.rag.models.document import Document | |||||
| from core.rag.models.document import ChildDocument, Document | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from models.dataset import Dataset, DocumentSegment | from models.dataset import Dataset, DocumentSegment | ||||
| from models.dataset import Document as DatasetDocument | from models.dataset import Document as DatasetDocument | ||||
| db.session.commit() | db.session.commit() | ||||
| # clean index | # clean index | ||||
| index_processor.clean(dataset, None, with_keywords=False) | |||||
| index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) | |||||
| for dataset_document in dataset_documents: | for dataset_document in dataset_documents: | ||||
| # update from vector index | # update from vector index | ||||
| "dataset_id": segment.dataset_id, | "dataset_id": segment.dataset_id, | ||||
| }, | }, | ||||
| ) | ) | ||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||||
| child_chunks = segment.child_chunks | |||||
| if child_chunks: | |||||
| child_documents = [] | |||||
| for child_chunk in child_chunks: | |||||
| child_document = ChildDocument( | |||||
| page_content=child_chunk.content, | |||||
| metadata={ | |||||
| "doc_id": child_chunk.index_node_id, | |||||
| "doc_hash": child_chunk.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| child_documents.append(child_document) | |||||
| document.children = child_documents | |||||
| documents.append(document) | documents.append(document) | ||||
| # save vector index | # save vector index | ||||
| index_processor.load(dataset, documents, with_keywords=False) | index_processor.load(dataset, documents, with_keywords=False) |
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset, Document | from models.dataset import Dataset, Document | ||||
| @shared_task(queue="dataset") | @shared_task(queue="dataset") | ||||
| def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): | |||||
| def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): | |||||
| """ | """ | ||||
| Async Remove segment from index | Async Remove segment from index | ||||
| :param segment_id: | |||||
| :param index_node_id: | |||||
| :param index_node_ids: | |||||
| :param dataset_id: | :param dataset_id: | ||||
| :param document_id: | :param document_id: | ||||
| Usage: delete_segment_from_index_task.delay(segment_id) | |||||
| Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id) | |||||
| """ | """ | ||||
| logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) | |||||
| logging.info(click.style("Start delete segment from index", fg="green")) | |||||
| start_at = time.perf_counter() | start_at = time.perf_counter() | ||||
| indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) | |||||
| try: | try: | ||||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | ||||
| if not dataset: | if not dataset: | ||||
| logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) | |||||
| return | return | ||||
| dataset_document = db.session.query(Document).filter(Document.id == document_id).first() | dataset_document = db.session.query(Document).filter(Document.id == document_id).first() | ||||
| if not dataset_document: | if not dataset_document: | ||||
| logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) | |||||
| return | return | ||||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | ||||
| logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) | |||||
| return | return | ||||
| index_type = dataset_document.doc_form | index_type = dataset_document.doc_form | ||||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | index_processor = IndexProcessorFactory(index_type).init_index_processor() | ||||
| index_processor.clean(dataset, [index_node_id]) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| end_at = time.perf_counter() | end_at = time.perf_counter() | ||||
| logging.info( | |||||
| click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") | |||||
| ) | |||||
| logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) | |||||
| except Exception: | except Exception: | ||||
| logging.exception("delete segment from index failed") | logging.exception("delete segment from index failed") | ||||
| finally: | |||||
| redis_client.delete(indexing_cache_key) |
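Note that despite the task name, the list argument now carries index_node_id values rather than segment primary keys; a dispatch from the caller side would look roughly like this (variable names assumed):

    # Collect the vector-index node ids of the segments being deleted, then hand them off.
    index_node_ids = [segment.index_node_id for segment in segments]
    delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)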
| import logging | |||||
| import time | |||||
| import click | |||||
| from celery import shared_task | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||||
| from extensions.ext_database import db | |||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | |||||
| @shared_task(queue="dataset") | |||||
| def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str): | |||||
| """ | |||||
| Async disable segments from index | |||||
| :param segment_ids: | |||||
| Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id) | |||||
| """ | |||||
| start_at = time.perf_counter() | |||||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||||
| if not dataset: | |||||
| logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) | |||||
| return | |||||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||||
| if not dataset_document: | |||||
| logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) | |||||
| return | |||||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | |||||
| logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) | |||||
| return | |||||
| # sync index processor | |||||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .filter( | |||||
| DocumentSegment.id.in_(segment_ids), | |||||
| DocumentSegment.dataset_id == dataset_id, | |||||
| DocumentSegment.document_id == document_id, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| if not segments: | |||||
| return | |||||
| try: | |||||
| index_node_ids = [segment.index_node_id for segment in segments] | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) | |||||
| end_at = time.perf_counter() | |||||
| logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) | |||||
| except Exception: | |||||
| logging.exception("disable segments from index failed") | |||||
| # revert segments to the enabled state so they stay consistent with the index | |||||
| db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.id.in_(segment_ids), | |||||
| DocumentSegment.dataset_id == dataset_id, | |||||
| DocumentSegment.document_id == document_id, | |||||
| ).update( | |||||
| { | |||||
| "disabled_at": None, | |||||
| "disabled_by": None, | |||||
| "enabled": True, | |||||
| } | |||||
| ) | |||||
| db.session.commit() | |||||
| finally: | |||||
| for segment in segments: | |||||
| indexing_cache_key = "segment_{}_indexing".format(segment.id) | |||||
| redis_client.delete(indexing_cache_key) |
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) |
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) |
| if document: | if document: | ||||
| document.indexing_status = "error" | document.indexing_status = "error" | ||||
| document.error = str(e) | document.error = str(e) | ||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| return | return | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| db.session.commit() | db.session.commit() | ||||
| document.indexing_status = "parsing" | document.indexing_status = "parsing" | ||||
| document.processing_started_at = datetime.datetime.utcnow() | |||||
| document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| documents.append(document) | documents.append(document) | ||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() |
| from celery import shared_task # type: ignore | from celery import shared_task # type: ignore | ||||
| from werkzeug.exceptions import NotFound | from werkzeug.exceptions import NotFound | ||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | ||||
| from core.rag.models.document import Document | |||||
| from core.rag.models.document import ChildDocument, Document | |||||
| from extensions.ext_database import db | from extensions.ext_database import db | ||||
| from extensions.ext_redis import redis_client | from extensions.ext_redis import redis_client | ||||
| from models.dataset import DocumentSegment | from models.dataset import DocumentSegment | ||||
| return | return | ||||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | ||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||||
| child_chunks = segment.child_chunks | |||||
| if child_chunks: | |||||
| child_documents = [] | |||||
| for child_chunk in child_chunks: | |||||
| child_document = ChildDocument( | |||||
| page_content=child_chunk.content, | |||||
| metadata={ | |||||
| "doc_id": child_chunk.index_node_id, | |||||
| "doc_hash": child_chunk.index_node_hash, | |||||
| "document_id": segment.document_id, | |||||
| "dataset_id": segment.dataset_id, | |||||
| }, | |||||
| ) | |||||
| child_documents.append(child_document) | |||||
| document.children = child_documents | |||||
| # save vector index | # save vector index | ||||
| index_processor.load(dataset, [document]) | index_processor.load(dataset, [document]) | ||||
| import datetime | |||||
| import logging | |||||
| import time | |||||
| import click | |||||
| from celery import shared_task | |||||
| from core.rag.index_processor.constant.index_type import IndexType | |||||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||||
| from core.rag.models.document import ChildDocument, Document | |||||
| from extensions.ext_database import db | |||||
| from extensions.ext_redis import redis_client | |||||
| from models.dataset import Dataset, DocumentSegment | |||||
| from models.dataset import Document as DatasetDocument | |||||
| @shared_task(queue="dataset") | |||||
| def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): | |||||
| """ | |||||
| Async enable segments to index | |||||
| :param segment_ids: list of segment IDs to enable | |||||
| :param dataset_id: ID of the dataset the segments belong to | |||||
| :param document_id: ID of the document the segments belong to | |||||
| Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) | |||||
| """ | |||||
| start_at = time.perf_counter() | |||||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||||
| if not dataset: | |||||
| logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) | |||||
| return | |||||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||||
| if not dataset_document: | |||||
| logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) | |||||
| return | |||||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | |||||
| logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) | |||||
| return | |||||
| # sync index processor | |||||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||||
| segments = ( | |||||
| db.session.query(DocumentSegment) | |||||
| .filter( | |||||
| DocumentSegment.id.in_(segment_ids), | |||||
| DocumentSegment.dataset_id == dataset_id, | |||||
| DocumentSegment.document_id == document_id, | |||||
| ) | |||||
| .all() | |||||
| ) | |||||
| if not segments: | |||||
| return | |||||
| try: | |||||
| documents = [] | |||||
| for segment in segments: | |||||
| document = Document( | |||||
| page_content=segment.content, | |||||
| metadata={ | |||||
| "doc_id": segment.index_node_id, | |||||
| "doc_hash": segment.index_node_hash, | |||||
| "document_id": document_id, | |||||
| "dataset_id": dataset_id, | |||||
| }, | |||||
| ) | |||||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||||
| child_chunks = segment.child_chunks | |||||
| if child_chunks: | |||||
| child_documents = [] | |||||
| for child_chunk in child_chunks: | |||||
| child_document = ChildDocument( | |||||
| page_content=child_chunk.content, | |||||
| metadata={ | |||||
| "doc_id": child_chunk.index_node_id, | |||||
| "doc_hash": child_chunk.index_node_hash, | |||||
| "document_id": document_id, | |||||
| "dataset_id": dataset_id, | |||||
| }, | |||||
| ) | |||||
| child_documents.append(child_document) | |||||
| document.children = child_documents | |||||
| documents.append(document) | |||||
| # save vector index | |||||
| index_processor.load(dataset, documents) | |||||
| end_at = time.perf_counter() | |||||
| logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) | |||||
| except Exception as e: | |||||
| logging.exception("enable segments to index failed") | |||||
| # update segment error msg | |||||
| db.session.query(DocumentSegment).filter( | |||||
| DocumentSegment.id.in_(segment_ids), | |||||
| DocumentSegment.dataset_id == dataset_id, | |||||
| DocumentSegment.document_id == document_id, | |||||
| ).update( | |||||
| { | |||||
| "error": str(e), | |||||
| "status": "error", | |||||
| "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), | |||||
| "enabled": False, | |||||
| } | |||||
| ) | |||||
| db.session.commit() | |||||
| finally: | |||||
| for segment in segments: | |||||
| indexing_cache_key = "segment_{}_indexing".format(segment.id) | |||||
| redis_client.delete(indexing_cache_key) |
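The new enable_segments_to_index_task mirrors the existing per-document tasks: it re-validates the dataset and document, rebuilds Document (and, for parent-child indexes, ChildDocument) objects for the selected segments, loads them into the vector index, and marks the segments as errored on failure while always clearing the per-segment indexing cache keys. A hedged usage sketch, assuming the task module lives under tasks/ like the other tasks in this diff:

    # Illustrative dispatch from a service layer; only the task name and the
    # argument order come from the definition above, the module path is assumed.
    from tasks.enable_segments_to_index_task import enable_segments_to_index_task

    segment_ids = [segment.id for segment in segments]
    enable_segments_to_index_task.delay(segment_ids, dataset.id, document.id)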
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| if index_node_ids: | if index_node_ids: | ||||
| try: | try: | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) | |||||
| except Exception: | except Exception: | ||||
| logging.exception(f"clean dataset {dataset.id} from index failed") | logging.exception(f"clean dataset {dataset.id} from index failed") | ||||
| if document: | if document: | ||||
| document.indexing_status = "error" | document.indexing_status = "error" | ||||
| document.error = str(e) | document.error = str(e) | ||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| redis_client.delete(retry_indexing_cache_key) | redis_client.delete(retry_indexing_cache_key) | ||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| db.session.commit() | db.session.commit() | ||||
| document.indexing_status = "parsing" | document.indexing_status = "parsing" | ||||
| document.processing_started_at = datetime.datetime.utcnow() | |||||
| document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| except Exception as ex: | except Exception as ex: | ||||
| document.indexing_status = "error" | document.indexing_status = "error" | ||||
| document.error = str(ex) | document.error = str(ex) | ||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| logging.info(click.style(str(ex), fg="yellow")) | logging.info(click.style(str(ex), fg="yellow")) |
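The timestamp change repeated throughout this diff swaps the deprecated datetime.datetime.utcnow() for a timezone-aware now() that is immediately stripped back to naive, so the existing naive-UTC DateTime columns keep receiving the same kind of value. A small sketch of that pattern as a helper; naive_utc_now is an illustrative name, not part of this diff:

    import datetime

    def naive_utc_now() -> datetime.datetime:
        # Timezone-aware UTC timestamp, converted back to naive for the existing columns.
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    print(naive_utc_now())  # e.g. document.stopped_at = naive_utc_now()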
| if document: | if document: | ||||
| document.indexing_status = "error" | document.indexing_status = "error" | ||||
| document.error = str(e) | document.error = str(e) | ||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| redis_client.delete(sync_indexing_cache_key) | redis_client.delete(sync_indexing_cache_key) | ||||
| if segments: | if segments: | ||||
| index_node_ids = [segment.index_node_id for segment in segments] | index_node_ids = [segment.index_node_id for segment in segments] | ||||
| # delete from vector index | # delete from vector index | ||||
| index_processor.clean(dataset, index_node_ids) | |||||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||||
| for segment in segments: | for segment in segments: | ||||
| db.session.delete(segment) | db.session.delete(segment) | ||||
| db.session.commit() | db.session.commit() | ||||
| document.indexing_status = "parsing" | document.indexing_status = "parsing" | ||||
| document.processing_started_at = datetime.datetime.utcnow() | |||||
| document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| except Exception as ex: | except Exception as ex: | ||||
| document.indexing_status = "error" | document.indexing_status = "error" | ||||
| document.error = str(ex) | document.error = str(ex) | ||||
| document.stopped_at = datetime.datetime.utcnow() | |||||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||||
| db.session.add(document) | db.session.add(document) | ||||
| db.session.commit() | db.session.commit() | ||||
| logging.info(click.style(str(ex), fg="yellow")) | logging.info(click.style(str(ex), fg="yellow")) |
| <!DOCTYPE html> | |||||
| <html lang="en"> | |||||
| <head> | |||||
| <meta charset="UTF-8"> | |||||
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |||||
| <title>Documents Disabled Notification</title> | |||||
| <style> | |||||
| body { | |||||
| font-family: Arial, sans-serif; | |||||
| margin: 0; | |||||
| padding: 0; | |||||
| background-color: #f5f5f5; | |||||
| } | |||||
| .email-container { | |||||
| max-width: 600px; | |||||
| margin: 20px auto; | |||||
| background: #ffffff; | |||||
| border-radius: 10px; | |||||
| box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); | |||||
| overflow: hidden; | |||||
| } | |||||
| .header { | |||||
| background-color: #eef2fa; | |||||
| padding: 20px; | |||||
| text-align: center; | |||||
| } | |||||
| .header img { | |||||
| height: 40px; | |||||
| } | |||||
| .content { | |||||
| padding: 20px; | |||||
| line-height: 1.6; | |||||
| color: #333; | |||||
| } | |||||
| .content h1 { | |||||
| font-size: 24px; | |||||
| color: #222; | |||||
| } | |||||
| .content p { | |||||
| margin: 10px 0; | |||||
| } | |||||
| .content ul { | |||||
| padding-left: 20px; | |||||
| } | |||||
| .content ul li { | |||||
| margin-bottom: 10px; | |||||
| } | |||||
| .cta-button { | |||||
| display: block; | |||||
| margin: 20px auto; | |||||
| padding: 10px 20px; | |||||
| background-color: #4e89f9; | |||||
| color: #ffffff; | |||||
| text-align: center; | |||||
| text-decoration: none; | |||||
| border-radius: 5px; | |||||
| width: fit-content; | |||||
| } | |||||
| .footer { | |||||
| text-align: center; | |||||
| padding: 10px; | |||||
| font-size: 12px; | |||||
| color: #777; | |||||
| background-color: #f9f9f9; | |||||
| } | |||||
| </style> | |||||
| </head> | |||||
| <body> | |||||
| <div class="email-container"> | |||||
| <!-- Header --> | |||||
| <div class="header"> | |||||
| <img src="https://via.placeholder.com/150x40?text=Dify" alt="Dify Logo"> | |||||
| </div> | |||||
| <!-- Content --> | |||||
| <div class="content"> | |||||
| <h1>Some Documents in Your Knowledge Base Have Been Disabled</h1> | |||||
| <p>Dear {{userName}},</p> | |||||
| <p> | |||||
| We're sorry for the inconvenience. To ensure optimal performance, documents | |||||
| that haven’t been updated or accessed in the past 7 days have been disabled in | |||||
| your knowledge bases: | |||||
| </p> | |||||
| <ul> | |||||
| {{knowledge_details}} | |||||
| </ul> | |||||
| <p>You can re-enable them anytime.</p> | |||||
| <a href="{{url}}" class="cta-button">Re-enable in Dify</a> | |||||
| </div> | |||||
| <!-- Footer --> | |||||
| <div class="footer"> | |||||
| Sincerely,<br> | |||||
| The Dify Team | |||||
| </div> | |||||
| </div> | |||||
| </body> | |||||
| </html> |
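The template above is plain HTML with three placeholders: {{userName}}, {{knowledge_details}} (expected to contain pre-built <li> items), and {{url}}. A hedged rendering sketch using Jinja2; the file name, the variable values, and the use of Jinja2 itself are illustrative assumptions, since the mail-sending code is not part of this diff:

    from pathlib import Path
    from jinja2 import Template

    # Assumed filename; substitute the real path of the template added above.
    html = Template(Path("clean_document_job_mail_template-US.html").read_text()).render(
        userName="jane@example.com",
        knowledge_details="<li>Dataset A: 3 documents disabled</li>",
        url="https://cloud.dify.ai/datasets",
    )
    print(html[:200])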