| @@ -218,7 +218,7 @@ class DataSourceNotionApi(Resource): | |||
| args["doc_form"], | |||
| args["doc_language"], | |||
| ) | |||
| return response, 200 | |||
| return response.model_dump(), 200 | |||
| class DataSourceNotionDatasetSyncApi(Resource): | |||
| @@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource): | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| return response, 200 | |||
| return response.model_dump(), 200 | |||
| class DatasetRelatedAppListApi(Resource): | |||
| @@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource): | |||
| }, 200 | |||
| class DatasetAutoDisableLogApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, dataset_id): | |||
| dataset_id_str = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id_str) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200 | |||
| api.add_resource(DatasetListApi, "/datasets") | |||
| api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") | |||
| api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") | |||
| @@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") | |||
| api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") | |||
| api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") | |||
| api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") | |||
| api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs") | |||
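The new `DatasetAutoDisableLogApi` route registered above exposes a dataset's auto-disable logs to the console. A minimal sketch of how a client might call it; the base URL and session-based auth are assumptions, not part of this diff:

```python
import requests

# Illustrative console base URL; the real deployment prefix may differ.
CONSOLE_API = "http://localhost:5001/console/api"

def fetch_auto_disable_logs(session: requests.Session, dataset_id: str):
    """GET the logs served by the new auto-disable-logs endpoint."""
    resp = session.get(f"{CONSOLE_API}/datasets/{dataset_id}/auto-disable-logs")
    resp.raise_for_status()
    return resp.json()
```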
| @@ -52,6 +52,7 @@ from fields.document_fields import ( | |||
| from libs.login import login_required | |||
| from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile | |||
| from services.dataset_service import DatasetService, DocumentService | |||
| from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig | |||
| from tasks.add_document_to_index_task import add_document_to_index_task | |||
| from tasks.remove_document_from_index_task import remove_document_from_index_task | |||
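The import added above brings in `KnowledgeConfig`, which the handlers below now build from the parsed request arguments instead of passing raw dicts into `DocumentService`. Its definition lives outside this diff; a rough sketch of the shape implied by the fields referenced here, with the defaults being assumptions:

```python
from typing import Optional
from pydantic import BaseModel

class KnowledgeConfig(BaseModel):
    # Only the fields visible in these handlers; the real entity defines more.
    indexing_technique: Optional[str] = None
    doc_form: str = "text_model"
    doc_language: str = "English"
    embedding_model: Optional[str] = None
    embedding_model_provider: Optional[str] = None
    retrieval_model: Optional[dict] = None
    original_document_id: Optional[str] = None
    duplicate: bool = True
```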
| @@ -255,20 +256,22 @@ class DatasetDocumentListApi(Resource): | |||
| parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") | |||
| parser.add_argument("original_document_id", type=str, required=False, location="json") | |||
| parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| parser.add_argument( | |||
| "doc_language", type=str, default="English", required=False, nullable=False, location="json" | |||
| ) | |||
| parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| if not dataset.indexing_technique and not args["indexing_technique"]: | |||
| if not dataset.indexing_technique and not knowledge_config.indexing_technique: | |||
| raise ValueError("indexing_technique is required.") | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| try: | |||
| documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) | |||
| documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except QuotaExceededError: | |||
| @@ -278,6 +281,25 @@ class DatasetDocumentListApi(Resource): | |||
| return {"documents": documents, "batch": batch} | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, dataset_id): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| try: | |||
| document_ids = request.args.getlist("document_id") | |||
| DocumentService.delete_documents(dataset, document_ids) | |||
| except services.errors.document.DocumentIndexingError: | |||
| raise DocumentIndexingError("Cannot delete document during indexing.") | |||
| return {"result": "success"}, 204 | |||
| class DatasetInitApi(Resource): | |||
| @setup_required | |||
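The `delete` handler added to `DatasetDocumentListApi` above removes several documents in one request, reading the target ids from repeated `document_id` query parameters. A hedged example of the call shape (URL prefix and auth are assumptions):

```python
import requests

def delete_documents(session: requests.Session, base_url: str, dataset_id: str, document_ids: list[str]) -> None:
    """Bulk-delete documents; each id is passed as a repeated query parameter."""
    resp = session.delete(
        f"{base_url}/datasets/{dataset_id}/documents",
        params=[("document_id", doc_id) for doc_id in document_ids],
    )
    resp.raise_for_status()
```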
| @@ -313,9 +335,9 @@ class DatasetInitApi(Resource): | |||
| # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator | |||
| if not current_user.is_dataset_editor: | |||
| raise Forbidden() | |||
| if args["indexing_technique"] == "high_quality": | |||
| if args["embedding_model"] is None or args["embedding_model_provider"] is None: | |||
| knowledge_config = KnowledgeConfig(**args) | |||
| if knowledge_config.indexing_technique == "high_quality": | |||
| if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: | |||
| raise ValueError("embedding model and embedding model provider are required for high quality indexing.") | |||
| try: | |||
| model_manager = ModelManager() | |||
| @@ -334,11 +356,11 @@ class DatasetInitApi(Resource): | |||
| raise ProviderNotInitializeError(ex.description) | |||
| # validate args | |||
| DocumentService.document_create_args_validate(args) | |||
| DocumentService.document_create_args_validate(knowledge_config) | |||
| try: | |||
| dataset, documents, batch = DocumentService.save_document_without_dataset_id( | |||
| tenant_id=current_user.current_tenant_id, document_data=args, account=current_user | |||
| tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| @@ -409,7 +431,7 @@ class DocumentIndexingEstimateApi(DocumentResource): | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| return response | |||
| return response.model_dump(), 200 | |||
| class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| @@ -422,7 +444,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| documents = self.get_batch_documents(dataset_id, batch) | |||
| response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} | |||
| if not documents: | |||
| return response | |||
| return response, 200 | |||
| data_process_rule = documents[0].dataset_process_rule | |||
| data_process_rule_dict = data_process_rule.to_dict() | |||
| info_list = [] | |||
| @@ -509,7 +531,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): | |||
| raise ProviderNotInitializeError(ex.description) | |||
| except Exception as e: | |||
| raise IndexingEstimateError(str(e)) | |||
| return response | |||
| return response.model_dump(), 200 | |||
| class DocumentBatchIndexingStatusApi(DocumentResource): | |||
| @@ -582,7 +604,8 @@ class DocumentDetailApi(DocumentResource): | |||
| if metadata == "only": | |||
| response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} | |||
| elif metadata == "without": | |||
| process_rules = DatasetService.get_process_rules(dataset_id) | |||
| dataset_process_rules = DatasetService.get_process_rules(dataset_id) | |||
| document_process_rules = document.dataset_process_rule.to_dict() | |||
| data_source_info = document.data_source_detail_dict | |||
| response = { | |||
| "id": document.id, | |||
| @@ -590,7 +613,8 @@ class DocumentDetailApi(DocumentResource): | |||
| "data_source_type": document.data_source_type, | |||
| "data_source_info": data_source_info, | |||
| "dataset_process_rule_id": document.dataset_process_rule_id, | |||
| "dataset_process_rule": process_rules, | |||
| "dataset_process_rule": dataset_process_rules, | |||
| "document_process_rule": document_process_rules, | |||
| "name": document.name, | |||
| "created_from": document.created_from, | |||
| "created_by": document.created_by, | |||
| @@ -613,7 +637,8 @@ class DocumentDetailApi(DocumentResource): | |||
| "doc_language": document.doc_language, | |||
| } | |||
| else: | |||
| process_rules = DatasetService.get_process_rules(dataset_id) | |||
| dataset_process_rules = DatasetService.get_process_rules(dataset_id) | |||
| document_process_rules = document.dataset_process_rule.to_dict() | |||
| data_source_info = document.data_source_detail_dict | |||
| response = { | |||
| "id": document.id, | |||
| @@ -621,7 +646,8 @@ class DocumentDetailApi(DocumentResource): | |||
| "data_source_type": document.data_source_type, | |||
| "data_source_info": data_source_info, | |||
| "dataset_process_rule_id": document.dataset_process_rule_id, | |||
| "dataset_process_rule": process_rules, | |||
| "dataset_process_rule": dataset_process_rules, | |||
| "document_process_rule": document_process_rules, | |||
| "name": document.name, | |||
| "created_from": document.created_from, | |||
| "created_by": document.created_by, | |||
| @@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| def patch(self, dataset_id, document_id, action): | |||
| def patch(self, dataset_id, action): | |||
| dataset_id = str(dataset_id) | |||
| document_id = str(document_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if dataset is None: | |||
| raise NotFound("Dataset not found.") | |||
| @@ -774,84 +799,79 @@ class DocumentStatusApi(DocumentResource): | |||
| # check user's permission | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| document = self.get_document(dataset_id, document_id) | |||
| document_ids = request.args.getlist("document_id") | |||
| for document_id in document_ids: | |||
| document = self.get_document(dataset_id, document_id) | |||
| indexing_cache_key = "document_{}_indexing".format(document.id) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise InvalidActionError("Document is being indexed, please try again later") | |||
| indexing_cache_key = "document_{}_indexing".format(document.id) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") | |||
| if action == "enable": | |||
| if document.enabled: | |||
| raise InvalidActionError("Document already enabled.") | |||
| if action == "enable": | |||
| if document.enabled: | |||
| continue | |||
| document.enabled = True | |||
| document.disabled_at = None | |||
| document.disabled_by = None | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| document.enabled = True | |||
| document.disabled_at = None | |||
| document.disabled_by = None | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| add_document_to_index_task.delay(document_id) | |||
| add_document_to_index_task.delay(document_id) | |||
| elif action == "disable": | |||
| if not document.completed_at or document.indexing_status != "completed": | |||
| raise InvalidActionError(f"Document: {document.name} is not completed.") | |||
| if not document.enabled: | |||
| continue | |||
| return {"result": "success"}, 200 | |||
| document.enabled = False | |||
| document.disabled_at = datetime.now(UTC).replace(tzinfo=None) | |||
| document.disabled_by = current_user.id | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| elif action == "disable": | |||
| if not document.completed_at or document.indexing_status != "completed": | |||
| raise InvalidActionError("Document is not completed.") | |||
| if not document.enabled: | |||
| raise InvalidActionError("Document already disabled.") | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| document.enabled = False | |||
| document.disabled_at = datetime.now(UTC).replace(tzinfo=None) | |||
| document.disabled_by = current_user.id | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| remove_document_from_index_task.delay(document_id) | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| elif action == "archive": | |||
| if document.archived: | |||
| continue | |||
| remove_document_from_index_task.delay(document_id) | |||
| document.archived = True | |||
| document.archived_at = datetime.now(UTC).replace(tzinfo=None) | |||
| document.archived_by = current_user.id | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| return {"result": "success"}, 200 | |||
| if document.enabled: | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| elif action == "archive": | |||
| if document.archived: | |||
| raise InvalidActionError("Document already archived.") | |||
| remove_document_from_index_task.delay(document_id) | |||
| document.archived = True | |||
| document.archived_at = datetime.now(UTC).replace(tzinfo=None) | |||
| document.archived_by = current_user.id | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| elif action == "un_archive": | |||
| if not document.archived: | |||
| continue | |||
| document.archived = False | |||
| document.archived_at = None | |||
| document.archived_by = None | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| if document.enabled: | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| remove_document_from_index_task.delay(document_id) | |||
| return {"result": "success"}, 200 | |||
| elif action == "un_archive": | |||
| if not document.archived: | |||
| raise InvalidActionError("Document is not archived.") | |||
| document.archived = False | |||
| document.archived_at = None | |||
| document.archived_by = None | |||
| document.updated_at = datetime.now(UTC).replace(tzinfo=None) | |||
| db.session.commit() | |||
| # Set cache to prevent indexing the same document multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| add_document_to_index_task.delay(document_id) | |||
| add_document_to_index_task.delay(document_id) | |||
| return {"result": "success"}, 200 | |||
| else: | |||
| raise InvalidActionError() | |||
| else: | |||
| raise InvalidActionError() | |||
| return {"result": "success"}, 200 | |||
| class DocumentPauseApi(DocumentResource): | |||
| @@ -1022,7 +1042,7 @@ api.add_resource( | |||
| ) | |||
| api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") | |||
| api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") | |||
| api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>") | |||
| api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch") | |||
| api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause") | |||
| api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume") | |||
| api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry") | |||
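With the route change above, document status actions now apply to a batch: the action (`enable`, `disable`, `archive`, `un_archive`) stays in the path, while the document ids move to repeated `document_id` query parameters. A sketch of the new call, with base URL and auth assumed:

```python
import requests

def set_documents_status(session: requests.Session, base_url: str, dataset_id: str, action: str, document_ids: list[str]):
    """PATCH the batch status endpoint introduced in this diff."""
    resp = session.patch(
        f"{base_url}/datasets/{dataset_id}/documents/status/{action}/batch",
        params=[("document_id", doc_id) for doc_id in document_ids],
    )
    resp.raise_for_status()
    return resp.json()  # expected: {"result": "success"}
```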
| @@ -1,5 +1,4 @@ | |||
| import uuid | |||
| from datetime import UTC, datetime | |||
| import pandas as pd | |||
| from flask import request | |||
| @@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound | |||
| import services | |||
| from controllers.console import api | |||
| from controllers.console.app.error import ProviderNotInitializeError | |||
| from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError | |||
| from controllers.console.datasets.error import ( | |||
| ChildChunkDeleteIndexError, | |||
| ChildChunkIndexingError, | |||
| InvalidActionError, | |||
| NoFileUploadedError, | |||
| TooManyFilesError, | |||
| ) | |||
| from controllers.console.wraps import ( | |||
| account_initialization_required, | |||
| cloud_edition_billing_knowledge_limit_check, | |||
| @@ -20,15 +25,15 @@ from controllers.console.wraps import ( | |||
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from fields.segment_fields import segment_fields | |||
| from fields.segment_fields import child_chunk_fields, segment_fields | |||
| from libs.login import login_required | |||
| from models import DocumentSegment | |||
| from models.dataset import ChildChunk, DocumentSegment | |||
| from services.dataset_service import DatasetService, DocumentService, SegmentService | |||
| from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs | |||
| from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError | |||
| from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError | |||
| from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task | |||
| from tasks.disable_segment_from_index_task import disable_segment_from_index_task | |||
| from tasks.enable_segment_to_index_task import enable_segment_to_index_task | |||
| class DatasetDocumentSegmentListApi(Resource): | |||
| @@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| raise NotFound("Document not found.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("last_id", type=str, default=None, location="args") | |||
| parser.add_argument("limit", type=int, default=20, location="args") | |||
| parser.add_argument("status", type=str, action="append", default=[], location="args") | |||
| parser.add_argument("hit_count_gte", type=int, default=None, location="args") | |||
| parser.add_argument("enabled", type=str, default="all", location="args") | |||
| parser.add_argument("keyword", type=str, default=None, location="args") | |||
| parser.add_argument("page", type=int, default=1, location="args") | |||
| args = parser.parse_args() | |||
| last_id = args["last_id"] | |||
| page = args["page"] | |||
| limit = min(args["limit"], 100) | |||
| status_list = args["status"] | |||
| hit_count_gte = args["hit_count_gte"] | |||
| @@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| query = DocumentSegment.query.filter( | |||
| DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ) | |||
| if last_id is not None: | |||
| last_segment = db.session.get(DocumentSegment, str(last_id)) | |||
| if last_segment: | |||
| query = query.filter(DocumentSegment.position > last_segment.position) | |||
| else: | |||
| return {"data": [], "has_more": False, "limit": limit}, 200 | |||
| ).order_by(DocumentSegment.position.asc()) | |||
| if status_list: | |||
| query = query.filter(DocumentSegment.status.in_(status_list)) | |||
| @@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource): | |||
| elif args["enabled"].lower() == "false": | |||
| query = query.filter(DocumentSegment.enabled == False) | |||
| total = query.count() | |||
| segments = query.order_by(DocumentSegment.position).limit(limit + 1).all() | |||
| has_more = False | |||
| if len(segments) > limit: | |||
| has_more = True | |||
| segments = segments[:-1] | |||
| segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) | |||
| return { | |||
| "data": marshal(segments, segment_fields), | |||
| "doc_form": document.doc_form, | |||
| "has_more": has_more, | |||
| response = { | |||
| "data": marshal(segments.items, segment_fields), | |||
| "limit": limit, | |||
| "total": total, | |||
| }, 200 | |||
| "total": segments.total, | |||
| "total_pages": segments.pages, | |||
| "page": page, | |||
| } | |||
| return response, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, dataset_id, document_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| segment_ids = request.args.getlist("segment_id") | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| try: | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| SegmentService.delete_segments(segment_ids, document, dataset) | |||
| return {"result": "success"}, 200 | |||
| class DatasetDocumentSegmentApi(Resource): | |||
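Earlier in this hunk the segment listing switches from `last_id` cursor paging to Flask-SQLAlchemy's `paginate`, so the response now carries page metadata instead of `has_more`. Roughly, a page of results looks like this (values are illustrative):

```python
# Illustrative response for GET .../documents/<document_id>/segments?page=2&limit=20
example_response = {
    "data": ["<segments marshalled with segment_fields>"],
    "limit": 20,
    "total": 134,       # all segments matching the filters
    "total_pages": 7,   # ceil(total / limit)
    "page": 2,
}
```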
| @@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource): | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| def patch(self, dataset_id, segment_id, action): | |||
| def patch(self, dataset_id, document_id, action): | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| @@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource): | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| segment_ids = request.args.getlist("segment_id") | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| if segment.status != "completed": | |||
| raise NotFound("Segment is not completed, enable or disable function is not allowed") | |||
| document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) | |||
| document_indexing_cache_key = "document_{}_indexing".format(document.id) | |||
| cache_result = redis_client.get(document_indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise InvalidActionError("Document is being indexed, please try again later") | |||
| indexing_cache_key = "segment_{}_indexing".format(segment.id) | |||
| cache_result = redis_client.get(indexing_cache_key) | |||
| if cache_result is not None: | |||
| raise InvalidActionError("Segment is being indexed, please try again later") | |||
| if action == "enable": | |||
| if segment.enabled: | |||
| raise InvalidActionError("Segment is already enabled.") | |||
| segment.enabled = True | |||
| segment.disabled_at = None | |||
| segment.disabled_by = None | |||
| db.session.commit() | |||
| # Set cache to prevent indexing the same segment multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| enable_segment_to_index_task.delay(segment.id) | |||
| return {"result": "success"}, 200 | |||
| elif action == "disable": | |||
| if not segment.enabled: | |||
| raise InvalidActionError("Segment is already disabled.") | |||
| segment.enabled = False | |||
| segment.disabled_at = datetime.now(UTC).replace(tzinfo=None) | |||
| segment.disabled_by = current_user.id | |||
| db.session.commit() | |||
| # Set cache to prevent indexing the same segment multiple times | |||
| redis_client.setex(indexing_cache_key, 600, 1) | |||
| disable_segment_from_index_task.delay(segment.id) | |||
| return {"result": "success"}, 200 | |||
| else: | |||
| raise InvalidActionError() | |||
| try: | |||
| SegmentService.update_segments_status(segment_ids, action, dataset, document) | |||
| except Exception as e: | |||
| raise InvalidActionError(str(e)) | |||
| return {"result": "success"}, 200 | |||
| class DatasetDocumentSegmentAddApi(Resource): | |||
| @@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource): | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| parser.add_argument("answer", type=str, required=False, nullable=True, location="json") | |||
| parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") | |||
| parser.add_argument( | |||
| "regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json" | |||
| ) | |||
| args = parser.parse_args() | |||
| SegmentService.segment_create_args_validate(args, document) | |||
| segment = SegmentService.update_segment(args, segment, document, dataset) | |||
| segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| @setup_required | |||
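`update_segment` now takes a typed `SegmentUpdateArgs` built from the parsed arguments instead of a raw dict. The entity is defined outside this diff; a rough sketch based only on the arguments parsed above (not the authoritative model):

```python
from typing import Optional
from pydantic import BaseModel

class SegmentUpdateArgs(BaseModel):
    # Mirrors the reqparse arguments above; the real entity may carry more fields.
    content: str
    answer: Optional[str] = None
    keywords: Optional[list[str]] = None
    regenerate_child_chunks: bool = False
```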
| @@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource): | |||
| return {"job_id": job_id, "job_status": cache_result.decode()}, 200 | |||
| class ChildChunkAddApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| @cloud_edition_billing_knowledge_limit_check("add_segment") | |||
| def post(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| # check embedding model setting | |||
| if dataset.indexing_technique == "high_quality": | |||
| try: | |||
| model_manager = ModelManager() | |||
| model_manager.get_model_instance( | |||
| tenant_id=current_user.current_tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model, | |||
| ) | |||
| except LLMBadRequestError: | |||
| raise ProviderNotInitializeError( | |||
| "No Embedding Model available. Please configure a valid provider " | |||
| "in the Settings -> Model Provider." | |||
| ) | |||
| except ProviderTokenNotInitError as ex: | |||
| raise ProviderNotInitializeError(ex.description) | |||
| try: | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset) | |||
| except ChildChunkIndexingServiceError as e: | |||
| raise ChildChunkIndexingError(str(e)) | |||
| return {"data": marshal(child_chunk, child_chunk_fields)}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def get(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("limit", type=int, default=20, location="args") | |||
| parser.add_argument("keyword", type=str, default=None, location="args") | |||
| parser.add_argument("page", type=int, default=1, location="args") | |||
| args = parser.parse_args() | |||
| page = args["page"] | |||
| limit = min(args["limit"], 100) | |||
| keyword = args["keyword"] | |||
| child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword) | |||
| return { | |||
| "data": marshal(child_chunks.items, child_chunk_fields), | |||
| "total": child_chunks.total, | |||
| "total_pages": child_chunks.pages, | |||
| "page": page, | |||
| "limit": limit, | |||
| }, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| def patch(self, dataset_id, document_id, segment_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| try: | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("chunks", type=list, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")] | |||
| child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) | |||
| except ChildChunkIndexingServiceError as e: | |||
| raise ChildChunkIndexingError(str(e)) | |||
| return {"data": marshal(child_chunks, child_chunk_fields)}, 200 | |||
| class ChildChunkUpdateApi(Resource): | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| def delete(self, dataset_id, document_id, segment_id, child_chunk_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # check child chunk | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = ChildChunk.query.filter( | |||
| ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not child_chunk: | |||
| raise NotFound("Child chunk not found.") | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| try: | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| try: | |||
| SegmentService.delete_child_chunk(child_chunk, dataset) | |||
| except ChildChunkDeleteIndexServiceError as e: | |||
| raise ChildChunkDeleteIndexError(str(e)) | |||
| return {"result": "success"}, 200 | |||
| @setup_required | |||
| @login_required | |||
| @account_initialization_required | |||
| @cloud_edition_billing_resource_check("vector_space") | |||
| def patch(self, dataset_id, document_id, segment_id, child_chunk_id): | |||
| # check dataset | |||
| dataset_id = str(dataset_id) | |||
| dataset = DatasetService.get_dataset(dataset_id) | |||
| if not dataset: | |||
| raise NotFound("Dataset not found.") | |||
| # check user's model setting | |||
| DatasetService.check_dataset_model_setting(dataset) | |||
| # check document | |||
| document_id = str(document_id) | |||
| document = DocumentService.get_document(dataset_id, document_id) | |||
| if not document: | |||
| raise NotFound("Document not found.") | |||
| # check segment | |||
| segment_id = str(segment_id) | |||
| segment = DocumentSegment.query.filter( | |||
| DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not segment: | |||
| raise NotFound("Segment not found.") | |||
| # check child chunk | |||
| child_chunk_id = str(child_chunk_id) | |||
| child_chunk = ChildChunk.query.filter( | |||
| ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id | |||
| ).first() | |||
| if not child_chunk: | |||
| raise NotFound("Child chunk not found.") | |||
| # The role of the current user in the ta table must be admin, owner, or editor | |||
| if not current_user.is_editor: | |||
| raise Forbidden() | |||
| try: | |||
| DatasetService.check_dataset_permission(dataset, current_user) | |||
| except services.errors.account.NoPermissionError as e: | |||
| raise Forbidden(str(e)) | |||
| # validate args | |||
| parser = reqparse.RequestParser() | |||
| parser.add_argument("content", type=str, required=True, nullable=False, location="json") | |||
| args = parser.parse_args() | |||
| try: | |||
| child_chunk = SegmentService.update_child_chunk( | |||
| args.get("content"), child_chunk, segment, document, dataset | |||
| ) | |||
| except ChildChunkIndexingServiceError as e: | |||
| raise ChildChunkIndexingError(str(e)) | |||
| return {"data": marshal(child_chunk, child_chunk_fields)}, 200 | |||
| api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") | |||
| api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") | |||
| api.add_resource( | |||
| DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>" | |||
| ) | |||
| api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") | |||
| api.add_resource( | |||
| DatasetDocumentSegmentUpdateApi, | |||
| @@ -424,3 +651,11 @@ api.add_resource( | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", | |||
| "/datasets/batch_import_status/<uuid:job_id>", | |||
| ) | |||
| api.add_resource( | |||
| ChildChunkAddApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks", | |||
| ) | |||
| api.add_resource( | |||
| ChildChunkUpdateApi, | |||
| "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>", | |||
| ) | |||
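The two registrations above expose the new child-chunk management API nested under a segment. A hedged example of adding and then updating a child chunk through these routes (prefix and auth assumed):

```python
import requests

def add_child_chunk(session, base_url, dataset_id, document_id, segment_id, content):
    """POST a new child chunk under a segment."""
    resp = session.post(
        f"{base_url}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}/child_chunks",
        json={"content": content},
    )
    resp.raise_for_status()
    return resp.json()["data"]

def update_child_chunk(session, base_url, dataset_id, document_id, segment_id, child_chunk_id, content):
    """PATCH an existing child chunk's content."""
    resp = session.patch(
        f"{base_url}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}"
        f"/child_chunks/{child_chunk_id}",
        json={"content": content},
    )
    resp.raise_for_status()
    return resp.json()["data"]
```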
| @@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException): | |||
| error_code = "indexing_estimate_error" | |||
| description = "Knowledge indexing estimate failed: {message}" | |||
| code = 500 | |||
| class ChildChunkIndexingError(BaseHTTPException): | |||
| error_code = "child_chunk_indexing_error" | |||
| description = "Create child chunk index failed: {message}" | |||
| code = 500 | |||
| class ChildChunkDeleteIndexError(BaseHTTPException): | |||
| error_code = "child_chunk_delete_index_error" | |||
| description = "Delete child chunk index failed: {message}" | |||
| code = 500 | |||
| @@ -16,6 +16,7 @@ from extensions.ext_database import db | |||
| from fields.segment_fields import segment_fields | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from services.dataset_service import DatasetService, DocumentService, SegmentService | |||
| from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs | |||
| class SegmentApi(DatasetApiResource): | |||
| @@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource): | |||
| args = parser.parse_args() | |||
| SegmentService.segment_create_args_validate(args["segment"], document) | |||
| segment = SegmentService.update_segment(args["segment"], segment, document, dataset) | |||
| segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset) | |||
| return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 | |||
| @@ -0,0 +1,19 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| class PreviewDetail(BaseModel): | |||
| content: str | |||
| child_chunks: Optional[list[str]] = None | |||
| class QAPreviewDetail(BaseModel): | |||
| question: str | |||
| answer: str | |||
| class IndexingEstimate(BaseModel): | |||
| total_segments: int | |||
| preview: list[PreviewDetail] | |||
| qa_preview: Optional[list[QAPreviewDetail]] = None | |||
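These new entities are what `indexing_estimate` returns and what the controllers above serialize with `.model_dump()`. A small example of how the pieces fit together; the values are made up:

```python
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail

estimate = IndexingEstimate(
    total_segments=2,
    preview=[
        PreviewDetail(content="Parent chunk text ...", child_chunks=["child one", "child two"]),
        PreviewDetail(content="Another parent chunk"),
    ],
)

# What the controllers return as `response.model_dump(), 200`
payload = estimate.model_dump()
# {'total_segments': 2,
#  'preview': [{'content': 'Parent chunk text ...', 'child_chunks': ['child one', 'child two']},
#              {'content': 'Another parent chunk', 'child_chunks': None}],
#  'qa_preview': None}
```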
| @@ -8,34 +8,34 @@ import time | |||
| import uuid | |||
| from typing import Any, Optional, cast | |||
| from flask import Flask, current_app | |||
| from flask import current_app | |||
| from flask_login import current_user # type: ignore | |||
| from sqlalchemy.orm.exc import ObjectDeletedError | |||
| from configs import dify_config | |||
| from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail | |||
| from core.errors.error import ProviderTokenNotInitError | |||
| from core.llm_generator.llm_generator import LLMGenerator | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.docstore.dataset_docstore import DatasetDocumentStore | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import Document | |||
| from core.rag.models.document import ChildDocument, Document | |||
| from core.rag.splitter.fixed_text_splitter import ( | |||
| EnhanceRecursiveCharacterTextSplitter, | |||
| FixedRecursiveCharacterTextSplitter, | |||
| ) | |||
| from core.rag.splitter.text_splitter import TextSplitter | |||
| from core.tools.utils.text_processing_utils import remove_leading_symbols | |||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from extensions.ext_storage import storage | |||
| from libs import helper | |||
| from models.dataset import Dataset, DatasetProcessRule, DocumentSegment | |||
| from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from models.model import UploadFile | |||
| from services.feature_service import FeatureService | |||
| @@ -115,6 +115,9 @@ class IndexingRunner: | |||
| for document_segment in document_segments: | |||
| db.session.delete(document_segment) | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| # delete child chunks | |||
| db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete() | |||
| db.session.commit() | |||
| # get the process rule | |||
| processing_rule = ( | |||
| @@ -183,7 +186,22 @@ class IndexingRunner: | |||
| "dataset_id": document_segment.dataset_id, | |||
| }, | |||
| ) | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunks = document_segment.child_chunks | |||
| if child_chunks: | |||
| child_documents = [] | |||
| for child_chunk in child_chunks: | |||
| child_document = ChildDocument( | |||
| page_content=child_chunk.content, | |||
| metadata={ | |||
| "doc_id": child_chunk.index_node_id, | |||
| "doc_hash": child_chunk.index_node_hash, | |||
| "document_id": document_segment.document_id, | |||
| "dataset_id": document_segment.dataset_id, | |||
| }, | |||
| ) | |||
| child_documents.append(child_document) | |||
| document.children = child_documents | |||
| documents.append(document) | |||
| # build index | |||
| @@ -222,7 +240,7 @@ class IndexingRunner: | |||
| doc_language: str = "English", | |||
| dataset_id: Optional[str] = None, | |||
| indexing_technique: str = "economy", | |||
| ) -> dict: | |||
| ) -> IndexingEstimate: | |||
| """ | |||
| Estimate the indexing for the document. | |||
| """ | |||
| @@ -258,31 +276,38 @@ class IndexingRunner: | |||
| tenant_id=tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| preview_texts: list[str] = [] | |||
| preview_texts = [] | |||
| total_segments = 0 | |||
| index_type = doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| all_text_docs = [] | |||
| for extract_setting in extract_settings: | |||
| # extract | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) | |||
| all_text_docs.extend(text_docs) | |||
| processing_rule = DatasetProcessRule( | |||
| mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) | |||
| ) | |||
| # get splitter | |||
| splitter = self._get_splitter(processing_rule, embedding_model_instance) | |||
| # split to documents | |||
| documents = self._split_to_documents_for_estimate( | |||
| text_docs=text_docs, splitter=splitter, processing_rule=processing_rule | |||
| text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"]) | |||
| documents = index_processor.transform( | |||
| text_docs, | |||
| embedding_model_instance=embedding_model_instance, | |||
| process_rule=processing_rule.to_dict(), | |||
| tenant_id=current_user.current_tenant_id, | |||
| doc_language=doc_language, | |||
| preview=True, | |||
| ) | |||
| total_segments += len(documents) | |||
| for document in documents: | |||
| if len(preview_texts) < 5: | |||
| preview_texts.append(document.page_content) | |||
| if len(preview_texts) < 10: | |||
| if doc_form and doc_form == "qa_model": | |||
| preview_detail = QAPreviewDetail( | |||
| question=document.page_content, answer=document.metadata.get("answer") | |||
| ) | |||
| preview_texts.append(preview_detail) | |||
| else: | |||
| preview_detail = PreviewDetail(content=document.page_content) | |||
| if document.children: | |||
| preview_detail.child_chunks = [child.page_content for child in document.children] | |||
| preview_texts.append(preview_detail) | |||
| # delete image files and related db records | |||
| image_upload_file_ids = get_image_upload_file_ids(document.page_content) | |||
| @@ -299,15 +324,8 @@ class IndexingRunner: | |||
| db.session.delete(image_file) | |||
| if doc_form and doc_form == "qa_model": | |||
| if len(preview_texts) > 0: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document( | |||
| current_user.current_tenant_id, preview_texts[0], doc_language | |||
| ) | |||
| document_qa_list = self.format_split_text(response) | |||
| return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts} | |||
| return {"total_segments": total_segments, "preview": preview_texts} | |||
| return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[]) | |||
| return IndexingEstimate(total_segments=total_segments, preview=preview_texts) | |||
| def _extract( | |||
| self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict | |||
| @@ -401,31 +419,26 @@ class IndexingRunner: | |||
| @staticmethod | |||
| def _get_splitter( | |||
| processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] | |||
| processing_rule_mode: str, | |||
| max_tokens: int, | |||
| chunk_overlap: int, | |||
| separator: str, | |||
| embedding_model_instance: Optional[ModelInstance], | |||
| ) -> TextSplitter: | |||
| """ | |||
| Get the NodeParser object according to the processing rule. | |||
| """ | |||
| character_splitter: TextSplitter | |||
| if processing_rule.mode == "custom": | |||
| if processing_rule_mode in ["custom", "hierarchical"]: | |||
| # The user-defined segmentation rule | |||
| rules = json.loads(processing_rule.rules) | |||
| segmentation = rules["segmentation"] | |||
| max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | |||
| if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: | |||
| if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: | |||
| raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") | |||
| separator = segmentation["separator"] | |||
| if separator: | |||
| separator = separator.replace("\\n", "\n") | |||
| if segmentation.get("chunk_overlap"): | |||
| chunk_overlap = segmentation["chunk_overlap"] | |||
| else: | |||
| chunk_overlap = 0 | |||
| character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | |||
| chunk_size=segmentation["max_tokens"], | |||
| chunk_size=max_tokens, | |||
| chunk_overlap=chunk_overlap, | |||
| fixed_separator=separator, | |||
| separators=["\n\n", "。", ". ", " ", ""], | |||
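After this refactor `_get_splitter` no longer reads the `DatasetProcessRule` itself; callers unpack the rule and pass the pieces in. A hedged sketch of a caller-side helper (the fallback defaults are assumptions, not values from this diff):

```python
import json

def build_splitter_kwargs(processing_rule) -> dict:
    """Unpack a DatasetProcessRule into the keyword arguments the
    refactored _get_splitter expects. Illustrative only."""
    rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
    segmentation = rules.get("segmentation", {})
    return {
        "processing_rule_mode": processing_rule.mode,
        "max_tokens": segmentation.get("max_tokens", 500),
        "chunk_overlap": segmentation.get("chunk_overlap", 0),
        "separator": segmentation.get("separator") or "\n\n",
    }
```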
| @@ -443,142 +456,6 @@ class IndexingRunner: | |||
| return character_splitter | |||
| def _step_split( | |||
| self, | |||
| text_docs: list[Document], | |||
| splitter: TextSplitter, | |||
| dataset: Dataset, | |||
| dataset_document: DatasetDocument, | |||
| processing_rule: DatasetProcessRule, | |||
| ) -> list[Document]: | |||
| """ | |||
| Split the text documents into documents and save them to the document segment. | |||
| """ | |||
| documents = self._split_to_documents( | |||
| text_docs=text_docs, | |||
| splitter=splitter, | |||
| processing_rule=processing_rule, | |||
| tenant_id=dataset.tenant_id, | |||
| document_form=dataset_document.doc_form, | |||
| document_language=dataset_document.doc_language, | |||
| ) | |||
| # save node to document segment | |||
| doc_store = DatasetDocumentStore( | |||
| dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id | |||
| ) | |||
| # add document segments | |||
| doc_store.add_documents(documents) | |||
| # update document status to indexing | |||
| cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| self._update_document_index_status( | |||
| document_id=dataset_document.id, | |||
| after_indexing_status="indexing", | |||
| extra_update_params={ | |||
| DatasetDocument.cleaning_completed_at: cur_time, | |||
| DatasetDocument.splitting_completed_at: cur_time, | |||
| }, | |||
| ) | |||
| # update segment status to indexing | |||
| self._update_segments_by_document( | |||
| dataset_document_id=dataset_document.id, | |||
| update_params={ | |||
| DocumentSegment.status: "indexing", | |||
| DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), | |||
| }, | |||
| ) | |||
| return documents | |||
| def _split_to_documents( | |||
| self, | |||
| text_docs: list[Document], | |||
| splitter: TextSplitter, | |||
| processing_rule: DatasetProcessRule, | |||
| tenant_id: str, | |||
| document_form: str, | |||
| document_language: str, | |||
| ) -> list[Document]: | |||
| """ | |||
| Split the text documents into nodes. | |||
| """ | |||
| all_documents: list[Document] = [] | |||
| all_qa_documents: list[Document] = [] | |||
| for text_doc in text_docs: | |||
| # document clean | |||
| document_text = self._document_clean(text_doc.page_content, processing_rule) | |||
| text_doc.page_content = document_text | |||
| # parse document to nodes | |||
| documents = splitter.split_documents([text_doc]) | |||
| split_documents = [] | |||
| for document_node in documents: | |||
| if document_node.page_content.strip(): | |||
| if document_node.metadata is not None: | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document_node.page_content) | |||
| document_node.metadata["doc_id"] = doc_id | |||
| document_node.metadata["doc_hash"] = hash | |||
| # delete Splitter character | |||
| page_content = document_node.page_content | |||
| document_node.page_content = remove_leading_symbols(page_content) | |||
| if document_node.page_content: | |||
| split_documents.append(document_node) | |||
| all_documents.extend(split_documents) | |||
| # processing qa document | |||
| if document_form == "qa_model": | |||
| for i in range(0, len(all_documents), 10): | |||
| threads = [] | |||
| sub_documents = all_documents[i : i + 10] | |||
| for doc in sub_documents: | |||
| document_format_thread = threading.Thread( | |||
| target=self.format_qa_document, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "tenant_id": tenant_id, | |||
| "document_node": doc, | |||
| "all_qa_documents": all_qa_documents, | |||
| "document_language": document_language, | |||
| }, | |||
| ) | |||
| threads.append(document_format_thread) | |||
| document_format_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| return all_qa_documents | |||
| return all_documents | |||
| def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language): | |||
| format_documents = [] | |||
| if document_node.page_content is None or not document_node.page_content.strip(): | |||
| return | |||
| with flask_app.app_context(): | |||
| try: | |||
| # qa model document | |||
| response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language) | |||
| document_qa_list = self.format_split_text(response) | |||
| qa_documents = [] | |||
| for result in document_qa_list: | |||
| qa_document = Document( | |||
| page_content=result["question"], metadata=document_node.metadata.model_copy() | |||
| ) | |||
| if qa_document.metadata is not None: | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(result["question"]) | |||
| qa_document.metadata["answer"] = result["answer"] | |||
| qa_document.metadata["doc_id"] = doc_id | |||
| qa_document.metadata["doc_hash"] = hash | |||
| qa_documents.append(qa_document) | |||
| format_documents.extend(qa_documents) | |||
| except Exception as e: | |||
| logging.exception("Failed to format qa document") | |||
| all_qa_documents.extend(format_documents) | |||
| def _split_to_documents_for_estimate( | |||
| self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule | |||
| ) -> list[Document]: | |||
| @@ -624,11 +501,11 @@ class IndexingRunner: | |||
| return document_text | |||
| @staticmethod | |||
| def format_split_text(text): | |||
| def format_split_text(text: str) -> list[QAPreviewDetail]: | |||
| regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" | |||
| matches = re.findall(regex, text, re.UNICODE) | |||
| return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] | |||
| return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a] | |||
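`format_split_text` now returns typed `QAPreviewDetail` objects rather than plain dicts. For LLM output in the expected Q/A format it behaves roughly like this; the sample text and the import path are illustrative:

```python
from core.indexing_runner import IndexingRunner  # module path assumed

sample = (
    "Q1: What is a dataset?\n"
    "A1: A collection of documents.\n"
    "Q2: What is a segment?\n"
    "A2: One chunk\n   of a document."
)
details = IndexingRunner.format_split_text(sample)
# [QAPreviewDetail(question='What is a dataset?', answer='A collection of documents.'),
#  QAPreviewDetail(question='What is a segment?', answer='One chunk\nof a document.')]
```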
| def _load( | |||
| self, | |||
| @@ -654,13 +531,14 @@ class IndexingRunner: | |||
| indexing_start_at = time.perf_counter() | |||
| tokens = 0 | |||
| chunk_size = 10 | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: | |||
| # create keyword index | |||
| create_keyword_thread = threading.Thread( | |||
| target=self._process_keyword_index, | |||
| args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), | |||
| ) | |||
| create_keyword_thread.start() | |||
| # create keyword index | |||
| create_keyword_thread = threading.Thread( | |||
| target=self._process_keyword_index, | |||
| args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore | |||
| ) | |||
| create_keyword_thread.start() | |||
| if dataset.indexing_technique == "high_quality": | |||
| with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: | |||
| futures = [] | |||
| @@ -680,8 +558,8 @@ class IndexingRunner: | |||
| for future in futures: | |||
| tokens += future.result() | |||
| create_keyword_thread.join() | |||
| if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: | |||
| create_keyword_thread.join() | |||
| indexing_end_at = time.perf_counter() | |||
| # update document status to completed | |||
| @@ -793,28 +671,6 @@ class IndexingRunner: | |||
| DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) | |||
| db.session.commit() | |||
| @staticmethod | |||
| def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset): | |||
| """ | |||
| Batch add segments index processing | |||
| """ | |||
| documents = [] | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| documents.append(document) | |||
| # save vector index | |||
| index_type = dataset.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| index_processor.load(dataset, documents) | |||
| def _transform( | |||
| self, | |||
| index_processor: BaseIndexProcessor, | |||
| @@ -856,7 +712,7 @@ class IndexingRunner: | |||
| ) | |||
| # add document segments | |||
| doc_store.add_documents(documents) | |||
| doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX) | |||
| # update document status to indexing | |||
| cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) | |||
| @@ -6,11 +6,14 @@ from flask import Flask, current_app | |||
| from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.embedding.retrieval import RetrievalSegments | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.models.document import Document | |||
| from core.rag.rerank.rerank_type import RerankMode | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset | |||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.external_knowledge_service import ExternalDatasetService | |||
| default_retrieval_model = { | |||
| @@ -248,3 +251,88 @@ class RetrievalService: | |||
| @staticmethod | |||
| def escape_query_for_search(query: str) -> str: | |||
| return query.replace('"', '\\"') | |||
| @staticmethod | |||
| def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]: | |||
| records = [] | |||
| include_segment_ids = [] | |||
| segment_child_map = {} | |||
| for document in documents: | |||
| document_id = document.metadata["document_id"] | |||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||
| if not dataset_document: | |||
| continue | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_index_node_id = document.metadata["doc_id"] | |||
| result = ( | |||
| db.session.query(ChildChunk, DocumentSegment) | |||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||
| .filter( | |||
| ChildChunk.index_node_id == child_index_node_id, | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| ) | |||
| .first() | |||
| ) | |||
| if result: | |||
| child_chunk, segment = result | |||
| if not segment: | |||
| continue | |||
| if segment.id not in include_segment_ids: | |||
| include_segment_ids.append(segment.id) | |||
| child_chunk_detail = { | |||
| "id": child_chunk.id, | |||
| "content": child_chunk.content, | |||
| "position": child_chunk.position, | |||
| "score": document.metadata.get("score", 0.0), | |||
| } | |||
| map_detail = { | |||
| "max_score": document.metadata.get("score", 0.0), | |||
| "child_chunks": [child_chunk_detail], | |||
| } | |||
| segment_child_map[segment.id] = map_detail | |||
| record = { | |||
| "segment": segment, | |||
| } | |||
| records.append(record) | |||
| else: | |||
| child_chunk_detail = { | |||
| "id": child_chunk.id, | |||
| "content": child_chunk.content, | |||
| "position": child_chunk.position, | |||
| "score": document.metadata.get("score", 0.0), | |||
| } | |||
| segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail) | |||
| segment_child_map[segment.id]["max_score"] = max( | |||
| segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0) | |||
| ) | |||
| else: | |||
| continue | |||
| else: | |||
| index_node_id = document.metadata["doc_id"] | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.dataset_id == dataset_document.dataset_id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| continue | |||
| include_segment_ids.append(segment.id) | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata.get("score", None), | |||
| } | |||
| records.append(record) | |||
| for record in records: | |||
| if record["segment"].id in segment_child_map: | |||
| record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None) | |||
| record["score"] = segment_child_map[record["segment"].id]["max_score"] | |||
| return [RetrievalSegments(**record) for record in records] | |||
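A rough usage sketch of the new grouping helper, assuming a Flask app context and rows already present in the database; the IDs and score below are placeholders:

```python
# Hedged sketch, not production code: format_retrieval_documents looks up the parent
# DocumentSegment for each hit and, for parent-child datasets, groups the matched
# child chunks under it with the max child score.
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document

hits = [
    Document(
        page_content="matched child chunk text",
        metadata={"doc_id": "index-node-id", "document_id": "dataset-document-id", "score": 0.82},
    )
]

records = RetrievalService.format_retrieval_documents(hits)
for record in records:
    print(record.segment.id, record.score, record.child_chunks)
```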
| @@ -7,7 +7,7 @@ from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.models.document import Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||
| class DatasetDocumentStore: | |||
| @@ -60,7 +60,7 @@ class DatasetDocumentStore: | |||
| return output | |||
| def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: | |||
| def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None: | |||
| max_position = ( | |||
| db.session.query(func.max(DocumentSegment.position)) | |||
| .filter(DocumentSegment.document_id == self._document_id) | |||
| @@ -120,6 +120,23 @@ class DatasetDocumentStore: | |||
| segment_document.answer = doc.metadata.pop("answer", "") | |||
| db.session.add(segment_document) | |||
| db.session.flush() | |||
| if save_child and doc.children: | |||
| for position, child in enumerate(doc.children, start=1): | |||
| child_segment = ChildChunk( | |||
| tenant_id=self._dataset.tenant_id, | |||
| dataset_id=self._dataset.id, | |||
| document_id=self._document_id, | |||
| segment_id=segment_document.id, | |||
| position=position, | |||
| index_node_id=child.metadata["doc_id"], | |||
| index_node_hash=child.metadata["doc_hash"], | |||
| content=child.page_content, | |||
| word_count=len(child.page_content), | |||
| type="automatic", | |||
| created_by=self._user_id, | |||
| ) | |||
| db.session.add(child_segment) | |||
| else: | |||
| segment_document.content = doc.page_content | |||
| if doc.metadata.get("answer"): | |||
| @@ -127,6 +144,30 @@ class DatasetDocumentStore: | |||
| segment_document.index_node_hash = doc.metadata["doc_hash"] | |||
| segment_document.word_count = len(doc.page_content) | |||
| segment_document.tokens = tokens | |||
| if save_child and doc.children: | |||
| # delete the existing child chunks | |||
| db.session.query(ChildChunk).filter( | |||
| ChildChunk.tenant_id == self._dataset.tenant_id, | |||
| ChildChunk.dataset_id == self._dataset.id, | |||
| ChildChunk.document_id == self._document_id, | |||
| ChildChunk.segment_id == segment_document.id, | |||
| ).delete() | |||
| # add new child chunks | |||
| for position, child in enumerate(doc.children, start=1): | |||
| child_segment = ChildChunk( | |||
| tenant_id=self._dataset.tenant_id, | |||
| dataset_id=self._dataset.id, | |||
| document_id=self._document_id, | |||
| segment_id=segment_document.id, | |||
| position=position, | |||
| index_node_id=child.metadata["doc_id"], | |||
| index_node_hash=child.metadata["doc_hash"], | |||
| content=child.page_content, | |||
| word_count=len(child.page_content), | |||
| type="automatic", | |||
| created_by=self._user_id, | |||
| ) | |||
| db.session.add(child_segment) | |||
| db.session.commit() | |||
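A small, self-contained sketch of the document shape the new `save_child` path expects; the IDs and hashes are placeholders, and the store call itself is shown only as a comment because it needs a live database session:

```python
from core.rag.models.document import ChildDocument, Document

child = ChildDocument(
    page_content="child chunk text",
    metadata={"doc_id": "child-uuid", "doc_hash": "child-hash"},
)
parent = Document(
    page_content="parent chunk text",
    metadata={"doc_id": "parent-uuid", "doc_hash": "parent-hash"},
    children=[child],
)

# With save_child=True, each ChildDocument becomes a ChildChunk row attached to the
# parent's DocumentSegment, e.g.:
#   doc_store.add_documents(docs=[parent], save_child=True)
print(parent.children[0].metadata["doc_id"])
```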
| @@ -0,0 +1,23 @@ | |||
| from typing import Optional | |||
| from pydantic import BaseModel | |||
| from models.dataset import DocumentSegment | |||
| class RetrievalChildChunk(BaseModel): | |||
| """Retrieval segments.""" | |||
| id: str | |||
| content: str | |||
| score: float | |||
| position: int | |||
| class RetrievalSegments(BaseModel): | |||
| """Retrieval segments.""" | |||
| model_config = {"arbitrary_types_allowed": True} | |||
| segment: DocumentSegment | |||
| child_chunks: Optional[list[RetrievalChildChunk]] = None | |||
| score: Optional[float] = None | |||
| @@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst | |||
| from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor | |||
| from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor | |||
| from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor | |||
| from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor | |||
| from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor | |||
| from core.rag.extractor.word_extractor import WordExtractor | |||
| from core.rag.models.document import Document | |||
| @@ -141,11 +140,7 @@ class ExtractProcessor: | |||
| extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) | |||
| else: | |||
| # txt | |||
| extractor = ( | |||
| UnstructuredTextExtractor(file_path, unstructured_api_url) | |||
| if is_automatic | |||
| else TextExtractor(file_path, autodetect_encoding=True) | |||
| ) | |||
| extractor = TextExtractor(file_path, autodetect_encoding=True) | |||
| else: | |||
| if file_extension in {".xlsx", ".xls"}: | |||
| extractor = ExcelExtractor(file_path) | |||
| @@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor): | |||
| if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph | |||
| para = paragraphs.pop(0) | |||
| parsed_paragraph = parse_paragraph(para) | |||
| if parsed_paragraph: | |||
| if parsed_paragraph.strip(): | |||
| content.append(parsed_paragraph) | |||
| else: | |||
| content.append("\n") | |||
| elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table | |||
| table = tables.pop(0) | |||
| content.append(self._table_to_markdown(table, image_map)) | |||
| @@ -1,8 +1,7 @@ | |||
| from enum import Enum | |||
| class IndexType(Enum): | |||
| class IndexType(str, Enum): | |||
| PARAGRAPH_INDEX = "text_model" | |||
| QA_INDEX = "qa_model" | |||
| PARENT_CHILD_INDEX = "parent_child_index" | |||
| SUMMARY_INDEX = "summary_index" | |||
| PARENT_CHILD_INDEX = "hierarchical_model" | |||
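Because `IndexType` now mixes in `str`, its members compare equal to the raw strings stored in `Dataset.doc_form`, which is what lets the factory below drop the explicit `.value` access. A quick illustration (the class is restated here only for the example):

```python
from enum import Enum

class IndexType(str, Enum):
    PARAGRAPH_INDEX = "text_model"
    QA_INDEX = "qa_model"
    PARENT_CHILD_INDEX = "hierarchical_model"

# The str mixin makes members compare equal to the plain database strings,
# so `self._index_type == IndexType.PARAGRAPH_INDEX` works without `.value`.
assert IndexType.PARENT_CHILD_INDEX == "hierarchical_model"
assert "text_model" == IndexType.PARAGRAPH_INDEX
assert IndexType.QA_INDEX.value == "qa_model"
```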
| @@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||
| raise NotImplementedError | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||
| raise NotImplementedError | |||
| @abstractmethod | |||
| @@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC): | |||
| ) -> list[Document]: | |||
| raise NotImplementedError | |||
| def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: | |||
| def _get_splitter( | |||
| self, | |||
| processing_rule_mode: str, | |||
| max_tokens: int, | |||
| chunk_overlap: int, | |||
| separator: str, | |||
| embedding_model_instance: Optional[ModelInstance], | |||
| ) -> TextSplitter: | |||
| """ | |||
| Get the NodeParser object according to the processing rule. | |||
| """ | |||
| character_splitter: TextSplitter | |||
| if processing_rule["mode"] == "custom": | |||
| if processing_rule_mode in ["custom", "hierarchical"]: | |||
| # The user-defined segmentation rule | |||
| rules = processing_rule["rules"] | |||
| segmentation = rules["segmentation"] | |||
| max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH | |||
| if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: | |||
| if max_tokens < 50 or max_tokens > max_segmentation_tokens_length: | |||
| raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") | |||
| separator = segmentation["separator"] | |||
| if separator: | |||
| separator = separator.replace("\\n", "\n") | |||
| character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( | |||
| chunk_size=segmentation["max_tokens"], | |||
| chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, | |||
| chunk_size=max_tokens, | |||
| chunk_overlap=chunk_overlap, | |||
| fixed_separator=separator, | |||
| separators=["\n\n", "。", ". ", " ", ""], | |||
| embedding_model_instance=embedding_model_instance, | |||
| @@ -3,6 +3,7 @@ | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor | |||
| from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor | |||
| from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor | |||
| @@ -18,9 +19,11 @@ class IndexProcessorFactory: | |||
| if not self._index_type: | |||
| raise ValueError("Index type must be specified.") | |||
| if self._index_type == IndexType.PARAGRAPH_INDEX.value: | |||
| if self._index_type == IndexType.PARAGRAPH_INDEX: | |||
| return ParagraphIndexProcessor() | |||
| elif self._index_type == IndexType.QA_INDEX.value: | |||
| elif self._index_type == IndexType.QA_INDEX: | |||
| return QAIndexProcessor() | |||
| elif self._index_type == IndexType.PARENT_CHILD_INDEX: | |||
| return ParentChildIndexProcessor() | |||
| else: | |||
| raise ValueError(f"Index type {self._index_type} is not supported.") | |||
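A short sketch of the resulting dispatch; the `doc_form` string comes straight from the dataset row:

```python
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

# "hierarchical_model" (IndexType.PARENT_CHILD_INDEX) maps to the new processor.
processor = IndexProcessorFactory("hierarchical_model").init_index_processor()
print(type(processor).__name__)  # ParentChildIndexProcessor
```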
| @@ -13,21 +13,34 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import Document | |||
| from core.tools.utils.text_processing_utils import remove_leading_symbols | |||
| from libs import helper | |||
| from models.dataset import Dataset | |||
| from models.dataset import Dataset, DatasetProcessRule | |||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||
| class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| text_docs = ExtractProcessor.extract( | |||
| extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" | |||
| extract_setting=extract_setting, | |||
| is_automatic=( | |||
| kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" | |||
| ), | |||
| ) | |||
| return text_docs | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| process_rule = kwargs.get("process_rule") | |||
| if process_rule.get("mode") == "automatic": | |||
| automatic_rule = DatasetProcessRule.AUTOMATIC_RULES | |||
| rules = Rule(**automatic_rule) | |||
| else: | |||
| rules = Rule(**process_rule.get("rules")) | |||
| # Split the text documents into nodes. | |||
| splitter = self._get_splitter( | |||
| processing_rule=kwargs.get("process_rule", {}), | |||
| processing_rule_mode=process_rule.get("mode"), | |||
| max_tokens=rules.segmentation.max_tokens, | |||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||
| separator=rules.segmentation.separator, | |||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | |||
| ) | |||
| all_documents = [] | |||
| @@ -53,15 +66,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor): | |||
| all_documents.extend(split_documents) | |||
| return all_documents | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector = Vector(dataset) | |||
| vector.create(documents) | |||
| if with_keywords: | |||
| keywords_list = kwargs.get("keywords_list") | |||
| keyword = Keyword(dataset) | |||
| keyword.create(documents) | |||
| if keywords_list and len(keywords_list) > 0: | |||
| keyword.add_texts(documents, keywords_list=keywords_list) | |||
| else: | |||
| keyword.add_texts(documents) | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector = Vector(dataset) | |||
| if node_ids: | |||
| @@ -0,0 +1,189 @@ | |||
| """Paragraph index processor.""" | |||
| import uuid | |||
| from typing import Optional | |||
| from core.model_manager import ModelInstance | |||
| from core.rag.cleaner.clean_processor import CleanProcessor | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.extractor.entity.extract_setting import ExtractSetting | |||
| from core.rag.extractor.extract_processor import ExtractProcessor | |||
| from core.rag.index_processor.index_processor_base import BaseIndexProcessor | |||
| from core.rag.models.document import ChildDocument, Document | |||
| from extensions.ext_database import db | |||
| from libs import helper | |||
| from models.dataset import ChildChunk, Dataset, DocumentSegment | |||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||
| class ParentChildIndexProcessor(BaseIndexProcessor): | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| text_docs = ExtractProcessor.extract( | |||
| extract_setting=extract_setting, | |||
| is_automatic=( | |||
| kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" | |||
| ), | |||
| ) | |||
| return text_docs | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| process_rule = kwargs.get("process_rule") | |||
| rules = Rule(**process_rule.get("rules")) | |||
| all_documents = [] | |||
| if rules.parent_mode == ParentMode.PARAGRAPH: | |||
| # Split the text documents into nodes. | |||
| splitter = self._get_splitter( | |||
| processing_rule_mode=process_rule.get("mode"), | |||
| max_tokens=rules.segmentation.max_tokens, | |||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||
| separator=rules.segmentation.separator, | |||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | |||
| ) | |||
| for document in documents: | |||
| # document clean | |||
| document_text = CleanProcessor.clean(document.page_content, process_rule) | |||
| document.page_content = document_text | |||
| # parse document to nodes | |||
| document_nodes = splitter.split_documents([document]) | |||
| split_documents = [] | |||
| for document_node in document_nodes: | |||
| if document_node.page_content.strip(): | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document_node.page_content) | |||
| document_node.metadata["doc_id"] = doc_id | |||
| document_node.metadata["doc_hash"] = hash | |||
| # delete Splitter character | |||
| page_content = document_node.page_content | |||
| if page_content.startswith(".") or page_content.startswith("。"): | |||
| page_content = page_content[1:].strip() | |||
| if len(page_content) > 0: | |||
| document_node.page_content = page_content | |||
| # parse document to child nodes | |||
| child_nodes = self._split_child_nodes( | |||
| document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | |||
| ) | |||
| document_node.children = child_nodes | |||
| split_documents.append(document_node) | |||
| all_documents.extend(split_documents) | |||
| elif rules.parent_mode == ParentMode.FULL_DOC: | |||
| page_content = "\n".join([document.page_content for document in documents]) | |||
| document = Document(page_content=page_content, metadata=documents[0].metadata) | |||
| # parse document to child nodes | |||
| child_nodes = self._split_child_nodes( | |||
| document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance") | |||
| ) | |||
| document.children = child_nodes | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(document.page_content) | |||
| document.metadata["doc_id"] = doc_id | |||
| document.metadata["doc_hash"] = hash | |||
| all_documents.append(document) | |||
| return all_documents | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector = Vector(dataset) | |||
| for document in documents: | |||
| child_documents = document.children | |||
| if child_documents: | |||
| formatted_child_documents = [ | |||
| Document(**child_document.model_dump()) for child_document in child_documents | |||
| ] | |||
| vector.create(formatted_child_documents) | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||
| # node_ids is segment's node_ids | |||
| if dataset.indexing_technique == "high_quality": | |||
| delete_child_chunks = kwargs.get("delete_child_chunks") or False | |||
| vector = Vector(dataset) | |||
| if node_ids: | |||
| child_node_ids = ( | |||
| db.session.query(ChildChunk.index_node_id) | |||
| .join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id) | |||
| .filter( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.index_node_id.in_(node_ids), | |||
| ChildChunk.dataset_id == dataset.id, | |||
| ) | |||
| .all() | |||
| ) | |||
| child_node_ids = [child_node_id[0] for child_node_id in child_node_ids] | |||
| vector.delete_by_ids(child_node_ids) | |||
| if delete_child_chunks: | |||
| db.session.query(ChildChunk).filter( | |||
| ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids) | |||
| ).delete() | |||
| db.session.commit() | |||
| else: | |||
| vector.delete() | |||
| if delete_child_chunks: | |||
| db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete() | |||
| db.session.commit() | |||
| def retrieve( | |||
| self, | |||
| retrieval_method: str, | |||
| query: str, | |||
| dataset: Dataset, | |||
| top_k: int, | |||
| score_threshold: float, | |||
| reranking_model: dict, | |||
| ) -> list[Document]: | |||
| # Set search parameters. | |||
| results = RetrievalService.retrieve( | |||
| retrieval_method=retrieval_method, | |||
| dataset_id=dataset.id, | |||
| query=query, | |||
| top_k=top_k, | |||
| score_threshold=score_threshold, | |||
| reranking_model=reranking_model, | |||
| ) | |||
| # Organize results. | |||
| docs = [] | |||
| for result in results: | |||
| metadata = result.metadata | |||
| metadata["score"] = result.score | |||
| if result.score > score_threshold: | |||
| doc = Document(page_content=result.page_content, metadata=metadata) | |||
| docs.append(doc) | |||
| return docs | |||
| def _split_child_nodes( | |||
| self, | |||
| document_node: Document, | |||
| rules: Rule, | |||
| process_rule_mode: str, | |||
| embedding_model_instance: Optional[ModelInstance], | |||
| ) -> list[ChildDocument]: | |||
| child_splitter = self._get_splitter( | |||
| processing_rule_mode=process_rule_mode, | |||
| max_tokens=rules.subchunk_segmentation.max_tokens, | |||
| chunk_overlap=rules.subchunk_segmentation.chunk_overlap, | |||
| separator=rules.subchunk_segmentation.separator, | |||
| embedding_model_instance=embedding_model_instance, | |||
| ) | |||
| # parse document to child nodes | |||
| child_nodes = [] | |||
| child_documents = child_splitter.split_documents([document_node]) | |||
| for child_document_node in child_documents: | |||
| if child_document_node.page_content.strip(): | |||
| doc_id = str(uuid.uuid4()) | |||
| hash = helper.generate_text_hash(child_document_node.page_content) | |||
| child_document = ChildDocument( | |||
| page_content=child_document_node.page_content, metadata=document_node.metadata | |||
| ) | |||
| child_document.metadata["doc_id"] = doc_id | |||
| child_document.metadata["doc_hash"] = hash | |||
| child_page_content = child_document.page_content | |||
| if child_page_content.startswith(".") or child_page_content.startswith("。"): | |||
| child_page_content = child_page_content[1:].strip() | |||
| if len(child_page_content) > 0: | |||
| child_document.page_content = child_page_content | |||
| child_nodes.append(child_document) | |||
| return child_nodes | |||
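A hedged end-to-end sketch of the parent-child transform under `paragraph` mode; the rule values are illustrative only, and real calls come from the indexing runner with the full process rule and, for high-quality datasets, an embedding `ModelInstance`:

```python
# Sketch only: assumes the hierarchical rule shape defined in knowledge_entities
# and that the splitter falls back to a default tokenizer when no embedding model
# instance is supplied.
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import Document

processor = ParentChildIndexProcessor()
process_rule = {
    "mode": "hierarchical",
    "rules": {
        "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
        "segmentation": {"separator": "\n\n", "max_tokens": 500, "chunk_overlap": 50},
        "subchunk_segmentation": {"separator": "\n", "max_tokens": 200, "chunk_overlap": 0},
        "parent_mode": "paragraph",
    },
}

docs = processor.transform(
    [Document(page_content="Intro paragraph.\n\nSecond paragraph with more detail.")],
    process_rule=process_rule,
    embedding_model_instance=None,
)
for parent in docs:
    print(parent.metadata["doc_id"], len(parent.children or []))
```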
| @@ -21,18 +21,28 @@ from core.rag.models.document import Document | |||
| from core.tools.utils.text_processing_utils import remove_leading_symbols | |||
| from libs import helper | |||
| from models.dataset import Dataset | |||
| from services.entities.knowledge_entities.knowledge_entities import Rule | |||
| class QAIndexProcessor(BaseIndexProcessor): | |||
| def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: | |||
| text_docs = ExtractProcessor.extract( | |||
| extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" | |||
| extract_setting=extract_setting, | |||
| is_automatic=( | |||
| kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical" | |||
| ), | |||
| ) | |||
| return text_docs | |||
| def transform(self, documents: list[Document], **kwargs) -> list[Document]: | |||
| preview = kwargs.get("preview") | |||
| process_rule = kwargs.get("process_rule") | |||
| rules = Rule(**process_rule.get("rules")) | |||
| splitter = self._get_splitter( | |||
| processing_rule=kwargs.get("process_rule") or {}, | |||
| processing_rule_mode=process_rule.get("mode"), | |||
| max_tokens=rules.segmentation.max_tokens, | |||
| chunk_overlap=rules.segmentation.chunk_overlap, | |||
| separator=rules.segmentation.separator, | |||
| embedding_model_instance=kwargs.get("embedding_model_instance"), | |||
| ) | |||
| @@ -59,24 +69,33 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| document_node.page_content = remove_leading_symbols(page_content) | |||
| split_documents.append(document_node) | |||
| all_documents.extend(split_documents) | |||
| for i in range(0, len(all_documents), 10): | |||
| threads = [] | |||
| sub_documents = all_documents[i : i + 10] | |||
| for doc in sub_documents: | |||
| document_format_thread = threading.Thread( | |||
| target=self._format_qa_document, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), # type: ignore | |||
| "tenant_id": kwargs.get("tenant_id"), | |||
| "document_node": doc, | |||
| "all_qa_documents": all_qa_documents, | |||
| "document_language": kwargs.get("doc_language", "English"), | |||
| }, | |||
| ) | |||
| threads.append(document_format_thread) | |||
| document_format_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| if preview: | |||
| self._format_qa_document( | |||
| current_app._get_current_object(), | |||
| kwargs.get("tenant_id"), | |||
| all_documents[0], | |||
| all_qa_documents, | |||
| kwargs.get("doc_language", "English"), | |||
| ) | |||
| else: | |||
| for i in range(0, len(all_documents), 10): | |||
| threads = [] | |||
| sub_documents = all_documents[i : i + 10] | |||
| for doc in sub_documents: | |||
| document_format_thread = threading.Thread( | |||
| target=self._format_qa_document, | |||
| kwargs={ | |||
| "flask_app": current_app._get_current_object(), | |||
| "tenant_id": kwargs.get("tenant_id"), | |||
| "document_node": doc, | |||
| "all_qa_documents": all_qa_documents, | |||
| "document_language": kwargs.get("doc_language", "English"), | |||
| }, | |||
| ) | |||
| threads.append(document_format_thread) | |||
| document_format_thread.start() | |||
| for thread in threads: | |||
| thread.join() | |||
| return all_qa_documents | |||
| def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: | |||
| @@ -98,12 +117,12 @@ class QAIndexProcessor(BaseIndexProcessor): | |||
| raise ValueError(str(e)) | |||
| return text_docs | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): | |||
| def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs): | |||
| if dataset.indexing_technique == "high_quality": | |||
| vector = Vector(dataset) | |||
| vector.create(documents) | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): | |||
| def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs): | |||
| vector = Vector(dataset) | |||
| if node_ids: | |||
| vector.delete_by_ids(node_ids) | |||
| @@ -5,6 +5,19 @@ from typing import Any, Optional | |||
| from pydantic import BaseModel, Field | |||
| class ChildDocument(BaseModel): | |||
| """Class for storing a piece of text and associated metadata.""" | |||
| page_content: str | |||
| vector: Optional[list[float]] = None | |||
| """Arbitrary metadata about the page content (e.g., source, relationships to other | |||
| documents, etc.). | |||
| """ | |||
| metadata: Optional[dict] = Field(default_factory=dict) | |||
| class Document(BaseModel): | |||
| """Class for storing a piece of text and associated metadata.""" | |||
| @@ -19,6 +32,8 @@ class Document(BaseModel): | |||
| provider: Optional[str] = "dify" | |||
| children: Optional[list[ChildDocument]] = None | |||
| class BaseDocumentTransformer(ABC): | |||
| """Abstract base class for document transformation systems. | |||
| @@ -166,43 +166,29 @@ class DatasetRetrieval: | |||
| "content": item.page_content, | |||
| } | |||
| retrieval_resource_list.append(source) | |||
| document_score_list = {} | |||
| # deal with dify documents | |||
| if dify_documents: | |||
| for item in dify_documents: | |||
| if item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| index_node_ids = [document.metadata["doc_id"] for document in dify_documents] | |||
| segments = DocumentSegment.query.filter( | |||
| DocumentSegment.dataset_id.in_(dataset_ids), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ).all() | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| sorted_segments = sorted( | |||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||
| ) | |||
| for segment in sorted_segments: | |||
| records = RetrievalService.format_retrieval_documents(dify_documents) | |||
| if records: | |||
| for record in records: | |||
| segment = record.segment | |||
| if segment.answer: | |||
| document_context_list.append( | |||
| DocumentContext( | |||
| content=f"question:{segment.get_sign_content()} answer:{segment.answer}", | |||
| score=document_score_list.get(segment.index_node_id, None), | |||
| score=record.score, | |||
| ) | |||
| ) | |||
| else: | |||
| document_context_list.append( | |||
| DocumentContext( | |||
| content=segment.get_sign_content(), | |||
| score=document_score_list.get(segment.index_node_id, None), | |||
| score=record.score, | |||
| ) | |||
| ) | |||
| if show_retrieve_source: | |||
| for segment in sorted_segments: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = DatasetDocument.query.filter( | |||
| DatasetDocument.id == segment.document_id, | |||
| @@ -218,7 +204,7 @@ class DatasetRetrieval: | |||
| "data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": invoke_from.to_source(), | |||
| "score": document_score_list.get(segment.index_node_id, 0.0), | |||
| "score": record.score or 0.0, | |||
| } | |||
| if invoke_from.to_source() == "dev": | |||
| @@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelFeature, ModelType | |||
| from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel | |||
| from core.rag.datasource.retrieval_service import RetrievalService | |||
| from core.rag.retrieval.dataset_retrieval import DatasetRetrieval | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from core.variables import StringSegment | |||
| @@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult | |||
| from core.workflow.nodes.base import BaseNode | |||
| from core.workflow.nodes.enums import NodeType | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from models.dataset import Dataset, Document | |||
| from models.workflow import WorkflowNodeExecutionStatus | |||
| from .entities import KnowledgeRetrievalNodeData | |||
| @@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| "content": item.page_content, | |||
| } | |||
| retrieval_resource_list.append(source) | |||
| document_score_list: dict[str, float] = {} | |||
| # deal with dify documents | |||
| if dify_documents: | |||
| document_score_list = {} | |||
| for item in dify_documents: | |||
| if item.metadata.get("score"): | |||
| document_score_list[item.metadata["doc_id"]] = item.metadata["score"] | |||
| index_node_ids = [document.metadata["doc_id"] for document in dify_documents] | |||
| segments = DocumentSegment.query.filter( | |||
| DocumentSegment.dataset_id.in_(dataset_ids), | |||
| DocumentSegment.completed_at.isnot(None), | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.index_node_id.in_(index_node_ids), | |||
| ).all() | |||
| if segments: | |||
| index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} | |||
| sorted_segments = sorted( | |||
| segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) | |||
| ) | |||
| for segment in sorted_segments: | |||
| records = RetrievalService.format_retrieval_documents(dify_documents) | |||
| if records: | |||
| for record in records: | |||
| segment = record.segment | |||
| dataset = Dataset.query.filter_by(id=segment.dataset_id).first() | |||
| document = Document.query.filter( | |||
| Document.id == segment.document_id, | |||
| @@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| "document_data_source_type": document.data_source_type, | |||
| "segment_id": segment.id, | |||
| "retriever_from": "workflow", | |||
| "score": document_score_list.get(segment.index_node_id, None), | |||
| "score": record.score or 0.0, | |||
| "segment_hit_count": segment.hit_count, | |||
| "segment_word_count": segment.word_count, | |||
| "segment_position": segment.position, | |||
| @@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): | |||
| key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, | |||
| reverse=True, | |||
| ) | |||
| position = 1 | |||
| for item in retrieval_resource_list: | |||
| for position, item in enumerate(retrieval_resource_list, start=1): | |||
| item["metadata"]["position"] = position | |||
| position += 1 | |||
| return retrieval_resource_list | |||
| @classmethod | |||
| @@ -73,6 +73,7 @@ dataset_detail_fields = { | |||
| "embedding_available": fields.Boolean, | |||
| "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), | |||
| "tags": fields.List(fields.Nested(tag_fields)), | |||
| "doc_form": fields.String, | |||
| "external_knowledge_info": fields.Nested(external_knowledge_info_fields), | |||
| "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), | |||
| } | |||
| @@ -34,6 +34,7 @@ document_with_segments_fields = { | |||
| "data_source_info": fields.Raw(attribute="data_source_info_dict"), | |||
| "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), | |||
| "dataset_process_rule_id": fields.String, | |||
| "process_rule_dict": fields.Raw(attribute="process_rule_dict"), | |||
| "name": fields.String, | |||
| "created_from": fields.String, | |||
| "created_by": fields.String, | |||
| @@ -34,8 +34,16 @@ segment_fields = { | |||
| "document": fields.Nested(document_fields), | |||
| } | |||
| child_chunk_fields = { | |||
| "id": fields.String, | |||
| "content": fields.String, | |||
| "position": fields.Integer, | |||
| "score": fields.Float, | |||
| } | |||
| hit_testing_record_fields = { | |||
| "segment": fields.Nested(segment_fields), | |||
| "child_chunks": fields.List(fields.Nested(child_chunk_fields)), | |||
| "score": fields.Float, | |||
| "tsne_position": fields.Raw, | |||
| } | |||
| @@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore | |||
| from libs.helper import TimestampField | |||
| child_chunk_fields = { | |||
| "id": fields.String, | |||
| "segment_id": fields.String, | |||
| "content": fields.String, | |||
| "position": fields.Integer, | |||
| "word_count": fields.Integer, | |||
| "type": fields.String, | |||
| "created_at": TimestampField, | |||
| "updated_at": TimestampField, | |||
| } | |||
| segment_fields = { | |||
| "id": fields.String, | |||
| "position": fields.Integer, | |||
| @@ -20,10 +31,13 @@ segment_fields = { | |||
| "status": fields.String, | |||
| "created_by": fields.String, | |||
| "created_at": TimestampField, | |||
| "updated_at": TimestampField, | |||
| "updated_by": fields.String, | |||
| "indexing_at": TimestampField, | |||
| "completed_at": TimestampField, | |||
| "error": fields.String, | |||
| "stopped_at": TimestampField, | |||
| "child_chunks": fields.List(fields.Nested(child_chunk_fields)), | |||
| } | |||
| segment_list_response = { | |||
| @@ -0,0 +1,55 @@ | |||
| """parent-child-index | |||
| Revision ID: e19037032219 | |||
| Revises: 01d6889832f7 | |||
| Create Date: 2024-11-22 07:01:17.550037 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| # revision identifiers, used by Alembic. | |||
| revision = 'e19037032219' | |||
| down_revision = 'd7999dfa4aae' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('child_chunks', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('document_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('segment_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('position', sa.Integer(), nullable=False), | |||
| sa.Column('content', sa.Text(), nullable=False), | |||
| sa.Column('word_count', sa.Integer(), nullable=False), | |||
| sa.Column('index_node_id', sa.String(length=255), nullable=True), | |||
| sa.Column('index_node_hash', sa.String(length=255), nullable=True), | |||
| sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False), | |||
| sa.Column('created_by', models.types.StringUUID(), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('updated_by', models.types.StringUUID(), nullable=True), | |||
| sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.Column('indexing_at', sa.DateTime(), nullable=True), | |||
| sa.Column('completed_at', sa.DateTime(), nullable=True), | |||
| sa.Column('error', sa.Text(), nullable=True), | |||
| sa.PrimaryKeyConstraint('id', name='child_chunk_pkey') | |||
| ) | |||
| with op.batch_alter_table('child_chunks', schema=None) as batch_op: | |||
| batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('child_chunks', schema=None) as batch_op: | |||
| batch_op.drop_index('child_chunk_dataset_id_idx') | |||
| op.drop_table('child_chunks') | |||
| # ### end Alembic commands ### | |||
| @@ -0,0 +1,47 @@ | |||
| """add_auto_disabled_dataset_logs | |||
| Revision ID: 923752d42eb6 | |||
| Revises: e19037032219 | |||
| Create Date: 2024-12-25 11:37:55.467101 | |||
| """ | |||
| from alembic import op | |||
| import models as models | |||
| import sqlalchemy as sa | |||
| from sqlalchemy.dialects import postgresql | |||
| # revision identifiers, used by Alembic. | |||
| revision = '923752d42eb6' | |||
| down_revision = 'e19037032219' | |||
| branch_labels = None | |||
| depends_on = None | |||
| def upgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| op.create_table('dataset_auto_disable_logs', | |||
| sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), | |||
| sa.Column('tenant_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('dataset_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('document_id', models.types.StringUUID(), nullable=False), | |||
| sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False), | |||
| sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), | |||
| sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey') | |||
| ) | |||
| with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: | |||
| batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False) | |||
| batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False) | |||
| batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False) | |||
| # ### end Alembic commands ### | |||
| def downgrade(): | |||
| # ### commands auto generated by Alembic - please adjust! ### | |||
| with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op: | |||
| batch_op.drop_index('dataset_auto_disable_log_tenant_idx') | |||
| batch_op.drop_index('dataset_auto_disable_log_dataset_idx') | |||
| batch_op.drop_index('dataset_auto_disable_log_created_atx') | |||
| op.drop_table('dataset_auto_disable_logs') | |||
| # ### end Alembic commands ### | |||
| @@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB | |||
| from configs import dify_config | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from extensions.ext_storage import storage | |||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule | |||
| from .account import Account | |||
| from .engine import db | |||
| @@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| MODES = ["automatic", "custom"] | |||
| MODES = ["automatic", "custom", "hierarchical"] | |||
| PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] | |||
| AUTOMATIC_RULES: dict[str, Any] = { | |||
| "pre_processing_rules": [ | |||
| @@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] | |||
| "dataset_id": self.dataset_id, | |||
| "mode": self.mode, | |||
| "rules": self.rules_dict, | |||
| "created_by": self.created_by, | |||
| "created_at": self.created_at, | |||
| } | |||
| @property | |||
| @@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined] | |||
| .scalar() | |||
| ) | |||
| @property | |||
| def process_rule_dict(self): | |||
| if self.dataset_process_rule_id: | |||
| return self.dataset_process_rule.to_dict() | |||
| return None | |||
| def to_dict(self): | |||
| return { | |||
| "id": self.id, | |||
| @@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] | |||
| .first() | |||
| ) | |||
| @property | |||
| def child_chunks(self): | |||
| process_rule = self.document.dataset_process_rule | |||
| if process_rule and process_rule.mode == "hierarchical": | |||
| rules = Rule(**process_rule.rules_dict) | |||
| if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: | |||
| child_chunks = ( | |||
| db.session.query(ChildChunk) | |||
| .filter(ChildChunk.segment_id == self.id) | |||
| .order_by(ChildChunk.position.asc()) | |||
| .all() | |||
| ) | |||
| return child_chunks or [] | |||
| else: | |||
| return [] | |||
| else: | |||
| return [] | |||
| def get_sign_content(self): | |||
| signed_urls = [] | |||
| text = self.content | |||
| @@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] | |||
| return text | |||
| class ChildChunk(db.Model): | |||
| __tablename__ = "child_chunks" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), | |||
| db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"), | |||
| ) | |||
| # initial fields | |||
| id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| document_id = db.Column(StringUUID, nullable=False) | |||
| segment_id = db.Column(StringUUID, nullable=False) | |||
| position = db.Column(db.Integer, nullable=False) | |||
| content = db.Column(db.Text, nullable=False) | |||
| word_count = db.Column(db.Integer, nullable=False) | |||
| # indexing fields | |||
| index_node_id = db.Column(db.String(255), nullable=True) | |||
| index_node_hash = db.Column(db.String(255), nullable=True) | |||
| type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) | |||
| created_by = db.Column(StringUUID, nullable=False) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| indexing_at = db.Column(db.DateTime, nullable=True) | |||
| completed_at = db.Column(db.DateTime, nullable=True) | |||
| error = db.Column(db.Text, nullable=True) | |||
| @property | |||
| def dataset(self): | |||
| return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first() | |||
| @property | |||
| def document(self): | |||
| return db.session.query(Document).filter(Document.id == self.document_id).first() | |||
| @property | |||
| def segment(self): | |||
| return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() | |||
| class AppDatasetJoin(db.Model): # type: ignore[name-defined] | |||
| __tablename__ = "app_dataset_joins" | |||
| __table_args__ = ( | |||
| @@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| updated_by = db.Column(StringUUID, nullable=True) | |||
| updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) | |||
| class DatasetAutoDisableLog(db.Model): | |||
| __tablename__ = "dataset_auto_disable_logs" | |||
| __table_args__ = ( | |||
| db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), | |||
| db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"), | |||
| db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"), | |||
| db.Index("dataset_auto_disable_log_created_atx", "created_at"), | |||
| ) | |||
| id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) | |||
| tenant_id = db.Column(StringUUID, nullable=False) | |||
| dataset_id = db.Column(StringUUID, nullable=False) | |||
| document_id = db.Column(StringUUID, nullable=False) | |||
| notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) | |||
| created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) | |||
| @@ -10,7 +10,7 @@ from configs import dify_config | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, DatasetQuery, Document | |||
| from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document | |||
| from services.feature_service import FeatureService | |||
| @@ -75,6 +75,23 @@ def clean_unused_datasets_task(): | |||
| ) | |||
| if not dataset_query or len(dataset_query) == 0: | |||
| try: | |||
| # add auto disable log | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.dataset_id == dataset.id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| for document in documents: | |||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| document_id=document.id, | |||
| ) | |||
| db.session.add(dataset_auto_disable_log) | |||
| # remove index | |||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||
| index_processor.clean(dataset, None) | |||
| @@ -151,6 +168,23 @@ def clean_unused_datasets_task(): | |||
| else: | |||
| plan = plan_cache.decode() | |||
| if plan == "sandbox": | |||
| # add auto disable log | |||
| documents = ( | |||
| db.session.query(Document) | |||
| .filter( | |||
| Document.dataset_id == dataset.id, | |||
| Document.enabled == True, | |||
| Document.archived == False, | |||
| ) | |||
| .all() | |||
| ) | |||
| for document in documents: | |||
| dataset_auto_disable_log = DatasetAutoDisableLog( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| document_id=document.id, | |||
| ) | |||
| db.session.add(dataset_auto_disable_log) | |||
| # remove index | |||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||
| index_processor.clean(dataset, None) | |||
| @@ -0,0 +1,66 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from flask import render_template | |||
| from extensions.ext_mail import mail | |||
| from models.account import Account, Tenant, TenantAccountJoin | |||
| from models.dataset import Dataset, DatasetAutoDisableLog | |||
| @shared_task(queue="mail") | |||
| def send_document_clean_notify_task(): | |||
| """ | |||
| Async Send document clean notify mail | |||
| Usage: send_document_clean_notify_task.delay() | |||
| """ | |||
| if not mail.is_inited(): | |||
| return | |||
| logging.info(click.style("Start send document clean notify mail", fg="green")) | |||
| start_at = time.perf_counter() | |||
| # send document clean notify mail | |||
| try: | |||
| dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() | |||
| # group by tenant_id | |||
| dataset_auto_disable_logs_map = {} | |||
| for dataset_auto_disable_log in dataset_auto_disable_logs: | |||
| dataset_auto_disable_logs_map.setdefault(dataset_auto_disable_log.tenant_id, []).append(dataset_auto_disable_log) | |||
| for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items(): | |||
| knowledge_details = [] | |||
| tenant = Tenant.query.filter(Tenant.id == tenant_id).first() | |||
| if not tenant: | |||
| continue | |||
| current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() | |||
| if not current_owner_join: | |||
| continue | |||
| account = Account.query.filter(Account.id == current_owner_join.account_id).first() | |||
| if not account: | |||
| continue | |||
| dataset_auto_dataset_map = {} | |||
| for dataset_auto_disable_log in tenant_dataset_auto_disable_logs: | |||
| dataset_auto_dataset_map.setdefault(dataset_auto_disable_log.dataset_id, []).append( | |||
| dataset_auto_disable_log.document_id | |||
| ) | |||
| for dataset_id, document_ids in dataset_auto_dataset_map.items(): | |||
| dataset = Dataset.query.filter(Dataset.id == dataset_id).first() | |||
| if dataset: | |||
| document_count = len(document_ids) | |||
| knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>") | |||
| html_content = render_template( | |||
| "clean_document_job_mail_template-US.html", | |||
| knowledge_details=knowledge_details, | |||
| ) | |||
| mail.send(to=account.email, subject="Dify: some documents in your knowledge base were disabled", html=html_content) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green") | |||
| ) | |||
| except Exception: | |||
| logging.exception("Send invite member mail to {} failed".format(to)) | |||
| @@ -1,4 +1,5 @@ | |||
| from typing import Optional | |||
| from enum import Enum | |||
| from typing import Literal, Optional | |||
| from pydantic import BaseModel | |||
| @@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel): | |||
| answer: Optional[str] = None | |||
| keywords: Optional[list[str]] = None | |||
| enabled: Optional[bool] = None | |||
| class ParentMode(str, Enum): | |||
| FULL_DOC = "full-doc" | |||
| PARAGRAPH = "paragraph" | |||
| class NotionIcon(BaseModel): | |||
| type: str | |||
| url: Optional[str] = None | |||
| emoji: Optional[str] = None | |||
| class NotionPage(BaseModel): | |||
| page_id: str | |||
| page_name: str | |||
| page_icon: Optional[NotionIcon] = None | |||
| type: str | |||
| class NotionInfo(BaseModel): | |||
| workspace_id: str | |||
| pages: list[NotionPage] | |||
| class WebsiteInfo(BaseModel): | |||
| provider: str | |||
| job_id: str | |||
| urls: list[str] | |||
| only_main_content: bool = True | |||
| class FileInfo(BaseModel): | |||
| file_ids: list[str] | |||
| class InfoList(BaseModel): | |||
| data_source_type: Literal["upload_file", "notion_import", "website_crawl"] | |||
| notion_info_list: Optional[list[NotionInfo]] = None | |||
| file_info_list: Optional[FileInfo] = None | |||
| website_info_list: Optional[WebsiteInfo] = None | |||
| class DataSource(BaseModel): | |||
| info_list: InfoList | |||
| class PreProcessingRule(BaseModel): | |||
| id: str | |||
| enabled: bool | |||
| class Segmentation(BaseModel): | |||
| separator: str = "\n" | |||
| max_tokens: int | |||
| chunk_overlap: int = 0 | |||
| class Rule(BaseModel): | |||
| pre_processing_rules: Optional[list[PreProcessingRule]] = None | |||
| segmentation: Optional[Segmentation] = None | |||
| parent_mode: Optional[Literal["full-doc", "paragraph"]] = None | |||
| subchunk_segmentation: Optional[Segmentation] = None | |||
| class ProcessRule(BaseModel): | |||
| mode: Literal["automatic", "custom", "hierarchical"] | |||
| rules: Optional[Rule] = None | |||
| class RerankingModel(BaseModel): | |||
| reranking_provider_name: Optional[str] = None | |||
| reranking_model_name: Optional[str] = None | |||
| class RetrievalModel(BaseModel): | |||
| search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] | |||
| reranking_enable: bool | |||
| reranking_model: Optional[RerankingModel] = None | |||
| top_k: int | |||
| score_threshold_enabled: bool | |||
| score_threshold: Optional[float] = None | |||
| class KnowledgeConfig(BaseModel): | |||
| original_document_id: Optional[str] = None | |||
| duplicate: bool = True | |||
| indexing_technique: Literal["high_quality", "economy"] | |||
| data_source: Optional[DataSource] = None | |||
| process_rule: Optional[ProcessRule] = None | |||
| retrieval_model: Optional[RetrievalModel] = None | |||
| doc_form: str = "text_model" | |||
| doc_language: str = "English" | |||
| embedding_model: Optional[str] = None | |||
| embedding_model_provider: Optional[str] = None | |||
| name: Optional[str] = None | |||
| class SegmentUpdateArgs(BaseModel): | |||
| content: Optional[str] = None | |||
| answer: Optional[str] = None | |||
| keywords: Optional[list[str]] = None | |||
| regenerate_child_chunks: bool = False | |||
| enabled: Optional[bool] = None | |||
| class ChildChunkUpdateArgs(BaseModel): | |||
| id: Optional[str] = None | |||
| content: str | |||
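Editor's note: to make the new entity layer easier to review, here is a minimal, illustrative sketch (not part of the diff) of building a KnowledgeConfig for a parent-child setup. The doc_form value and all field values are assumptions chosen only to exercise the models above.

from services.entities.knowledge_entities.knowledge_entities import (
    DataSource,
    FileInfo,
    InfoList,
    KnowledgeConfig,
    ProcessRule,
    RetrievalModel,
    Rule,
    Segmentation,
)

# Hierarchical (parent-child) setup; "hierarchical_model" is an assumed doc_form value.
config = KnowledgeConfig(
    indexing_technique="high_quality",
    doc_form="hierarchical_model",
    data_source=DataSource(
        info_list=InfoList(
            data_source_type="upload_file",
            file_info_list=FileInfo(file_ids=["<upload-file-id>"]),
        )
    ),
    process_rule=ProcessRule(
        mode="hierarchical",
        rules=Rule(
            segmentation=Segmentation(separator="\n\n", max_tokens=1024),
            parent_mode="paragraph",
            subchunk_segmentation=Segmentation(separator="\n", max_tokens=256),
        ),
    ),
    retrieval_model=RetrievalModel(
        search_method="semantic_search",
        reranking_enable=False,
        top_k=3,
        score_threshold_enabled=False,
    ),
)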
| @@ -0,0 +1,9 @@ | |||
| from services.errors.base import BaseServiceError | |||
| class ChildChunkIndexingError(BaseServiceError): | |||
| description = "{message}" | |||
| class ChildChunkDeleteIndexError(BaseServiceError): | |||
| description = "{message}" | |||
| @@ -7,7 +7,7 @@ from core.rag.models.document import Document | |||
| from core.rag.retrieval.retrieval_methods import RetrievalMethod | |||
| from extensions.ext_database import db | |||
| from models.account import Account | |||
| from models.dataset import Dataset, DatasetQuery, DocumentSegment | |||
| from models.dataset import Dataset, DatasetQuery | |||
| default_retrieval_model = { | |||
| "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, | |||
| @@ -69,7 +69,7 @@ class HitTestingService: | |||
| db.session.add(dataset_query) | |||
| db.session.commit() | |||
| return dict(cls.compact_retrieve_response(dataset, query, all_documents)) | |||
| return cls.compact_retrieve_response(query, all_documents) | |||
| @classmethod | |||
| def external_retrieve( | |||
| @@ -106,41 +106,14 @@ class HitTestingService: | |||
| return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) | |||
| @classmethod | |||
| def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): | |||
| records = [] | |||
| for document in documents: | |||
| if document.metadata is None: | |||
| continue | |||
| index_node_id = document.metadata["doc_id"] | |||
| segment = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.dataset_id == dataset.id, | |||
| DocumentSegment.enabled == True, | |||
| DocumentSegment.status == "completed", | |||
| DocumentSegment.index_node_id == index_node_id, | |||
| ) | |||
| .first() | |||
| ) | |||
| if not segment: | |||
| continue | |||
| record = { | |||
| "segment": segment, | |||
| "score": document.metadata.get("score", None), | |||
| } | |||
| records.append(record) | |||
| def compact_retrieve_response(cls, query: str, documents: list[Document]): | |||
| records = RetrievalService.format_retrieval_documents(documents) | |||
| return { | |||
| "query": { | |||
| "content": query, | |||
| }, | |||
| "records": records, | |||
| "records": [record.model_dump() for record in records], | |||
| } | |||
| @classmethod | |||
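Editor's note: with the segment lookup moved into RetrievalService.format_retrieval_documents, the compacted hit-testing payload ends up shaped roughly like the sketch below. The inner record keys are assumptions based on the dict this method used to build; the real fields come from the record model's model_dump().

# Assumed shape only; "segment" and "score" mirror the record dict that was removed above.
expected_payload = {
    "query": {"content": "example question"},
    "records": [
        {"segment": {"id": "<segment-id>", "content": "<segment text>"}, "score": 0.82},
    ],
}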
| @@ -1,40 +1,68 @@ | |||
| from typing import Optional | |||
| from core.model_manager import ModelInstance, ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from core.rag.datasource.keyword.keyword_factory import Keyword | |||
| from core.rag.datasource.vdb.vector_factory import Vector | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import Document | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from extensions.ext_database import db | |||
| from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from services.entities.knowledge_entities.knowledge_entities import ParentMode | |||
| class VectorService: | |||
| @classmethod | |||
| def create_segments_vector( | |||
| cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset | |||
| cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str | |||
| ): | |||
| documents = [] | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| documents.append(document) | |||
| if dataset.indexing_technique == "high_quality": | |||
| # save vector index | |||
| vector = Vector(dataset=dataset) | |||
| vector.add_texts(documents, duplicate_check=True) | |||
| # save keyword index | |||
| keyword = Keyword(dataset) | |||
| for segment in segments: | |||
| if doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| document = DatasetDocument.query.filter_by(id=segment.document_id).first() | |||
| # get the process rule | |||
| processing_rule = ( | |||
| db.session.query(DatasetProcessRule) | |||
| .filter(DatasetProcessRule.id == document.dataset_process_rule_id) | |||
| .first() | |||
| ) | |||
| # get embedding model instance | |||
| if dataset.indexing_technique == "high_quality": | |||
| # check embedding model setting | |||
| model_manager = ModelManager() | |||
| if keywords_list and len(keywords_list) > 0: | |||
| keyword.add_texts(documents, keywords_list=keywords_list) | |||
| else: | |||
| keyword.add_texts(documents) | |||
| if dataset.embedding_model_provider: | |||
| embedding_model_instance = model_manager.get_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| provider=dataset.embedding_model_provider, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| model=dataset.embedding_model, | |||
| ) | |||
| else: | |||
| embedding_model_instance = model_manager.get_default_model_instance( | |||
| tenant_id=dataset.tenant_id, | |||
| model_type=ModelType.TEXT_EMBEDDING, | |||
| ) | |||
| else: | |||
| raise ValueError("The knowledge base index technique is not high quality!") | |||
| cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) | |||
| else: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| documents.append(document) | |||
| if len(documents) > 0: | |||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | |||
| index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) | |||
| @classmethod | |||
| def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): | |||
| @@ -65,3 +93,123 @@ class VectorService: | |||
| keyword.add_texts([document], keywords_list=[keywords]) | |||
| else: | |||
| keyword.add_texts([document]) | |||
| @classmethod | |||
| def generate_child_chunks( | |||
| cls, | |||
| segment: DocumentSegment, | |||
| dataset_document: Document, | |||
| dataset: Dataset, | |||
| embedding_model_instance: ModelInstance, | |||
| processing_rule: DatasetProcessRule, | |||
| regenerate: bool = False, | |||
| ): | |||
| index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() | |||
| if regenerate: | |||
| # delete child chunks | |||
| index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True) | |||
| # generate child chunks | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| # use full doc mode to generate segment's child chunk | |||
| processing_rule_dict = processing_rule.to_dict() | |||
| processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value | |||
| documents = index_processor.transform( | |||
| [document], | |||
| embedding_model_instance=embedding_model_instance, | |||
| process_rule=processing_rule_dict, | |||
| tenant_id=dataset.tenant_id, | |||
| doc_language=dataset_document.doc_language, | |||
| ) | |||
| # save child chunks | |||
| if len(documents) > 0 and len(documents[0].children) > 0: | |||
| index_processor.load(dataset, documents) | |||
| for position, child_chunk in enumerate(documents[0].children, start=1): | |||
| child_segment = ChildChunk( | |||
| tenant_id=dataset.tenant_id, | |||
| dataset_id=dataset.id, | |||
| document_id=dataset_document.id, | |||
| segment_id=segment.id, | |||
| position=position, | |||
| index_node_id=child_chunk.metadata["doc_id"], | |||
| index_node_hash=child_chunk.metadata["doc_hash"], | |||
| content=child_chunk.page_content, | |||
| word_count=len(child_chunk.page_content), | |||
| type="automatic", | |||
| created_by=dataset_document.created_by, | |||
| ) | |||
| db.session.add(child_segment) | |||
| db.session.commit() | |||
| @classmethod | |||
| def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset): | |||
| child_document = Document( | |||
| page_content=child_segment.content, | |||
| metadata={ | |||
| "doc_id": child_segment.index_node_id, | |||
| "doc_hash": child_segment.index_node_hash, | |||
| "document_id": child_segment.document_id, | |||
| "dataset_id": child_segment.dataset_id, | |||
| }, | |||
| ) | |||
| if dataset.indexing_technique == "high_quality": | |||
| # save vector index | |||
| vector = Vector(dataset=dataset) | |||
| vector.add_texts([child_document], duplicate_check=True) | |||
| @classmethod | |||
| def update_child_chunk_vector( | |||
| cls, | |||
| new_child_chunks: list[ChildChunk], | |||
| update_child_chunks: list[ChildChunk], | |||
| delete_child_chunks: list[ChildChunk], | |||
| dataset: Dataset, | |||
| ): | |||
| documents = [] | |||
| delete_node_ids = [] | |||
| for new_child_chunk in new_child_chunks: | |||
| new_child_document = Document( | |||
| page_content=new_child_chunk.content, | |||
| metadata={ | |||
| "doc_id": new_child_chunk.index_node_id, | |||
| "doc_hash": new_child_chunk.index_node_hash, | |||
| "document_id": new_child_chunk.document_id, | |||
| "dataset_id": new_child_chunk.dataset_id, | |||
| }, | |||
| ) | |||
| documents.append(new_child_document) | |||
| for update_child_chunk in update_child_chunks: | |||
| child_document = Document( | |||
| page_content=update_child_chunk.content, | |||
| metadata={ | |||
| "doc_id": update_child_chunk.index_node_id, | |||
| "doc_hash": update_child_chunk.index_node_hash, | |||
| "document_id": update_child_chunk.document_id, | |||
| "dataset_id": update_child_chunk.dataset_id, | |||
| }, | |||
| ) | |||
| documents.append(child_document) | |||
| delete_node_ids.append(update_child_chunk.index_node_id) | |||
| for delete_child_chunk in delete_child_chunks: | |||
| delete_node_ids.append(delete_child_chunk.index_node_id) | |||
| if dataset.indexing_technique == "high_quality": | |||
| # update vector index | |||
| vector = Vector(dataset=dataset) | |||
| if delete_node_ids: | |||
| vector.delete_by_ids(delete_node_ids) | |||
| if documents: | |||
| vector.add_texts(documents, duplicate_check=True) | |||
| @classmethod | |||
| def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset): | |||
| vector = Vector(dataset=dataset) | |||
| vector.delete_by_ids([child_chunk.index_node_id]) | |||
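Editor's note: a minimal sketch (not part of the diff) of how a caller could use the new helpers to rebuild one segment's child chunks. It assumes the dataset uses high_quality indexing with a configured embedding model, and that segment, dataset_document, dataset, and processing_rule are already-loaded ORM objects.

from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from services.vector_service import VectorService


def regenerate_segment_child_chunks(segment, dataset_document, dataset, processing_rule):
    # Resolve the dataset's embedding model, mirroring create_segments_vector above.
    model_manager = ModelManager()
    embedding_model_instance = model_manager.get_model_instance(
        tenant_id=dataset.tenant_id,
        provider=dataset.embedding_model_provider,
        model_type=ModelType.TEXT_EMBEDDING,
        model=dataset.embedding_model,
    )
    # regenerate=True first cleans the segment's existing child chunks from the index,
    # then re-splits the parent content and stores the new ChildChunk rows and vectors.
    VectorService.generate_child_chunks(
        segment, dataset_document, dataset, embedding_model_instance, processing_rule, regenerate=True
    )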
| @@ -6,12 +6,13 @@ import click | |||
| from celery import shared_task # type: ignore | |||
| from werkzeug.exceptions import NotFound | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import Document | |||
| from core.rag.models.document import ChildDocument, Document | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DatasetAutoDisableLog, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| from models.dataset import DocumentSegment | |||
| @shared_task(queue="dataset") | |||
| @@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str): | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunks = segment.child_chunks | |||
| if child_chunks: | |||
| child_documents = [] | |||
| for child_chunk in child_chunks: | |||
| child_document = ChildDocument( | |||
| page_content=child_chunk.content, | |||
| metadata={ | |||
| "doc_id": child_chunk.index_node_id, | |||
| "doc_hash": child_chunk.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| child_documents.append(child_document) | |||
| document.children = child_documents | |||
| documents.append(document) | |||
| dataset = dataset_document.dataset | |||
| @@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str): | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| index_processor.load(dataset, documents) | |||
| # delete auto disable log | |||
| db.session.query(DatasetAutoDisableLog).filter( | |||
| DatasetAutoDisableLog.document_id == dataset_document.id | |||
| ).delete() | |||
| db.session.commit() | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style( | |||
| @@ -0,0 +1,75 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.tools.utils.web_reader_tool import get_image_upload_file_ids | |||
| from extensions.ext_database import db | |||
| from extensions.ext_storage import storage | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.model import UploadFile | |||
| @shared_task(queue="dataset") | |||
| def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]): | |||
| """ | |||
| Clean document when document deleted. | |||
| :param document_ids: document ids | |||
| :param dataset_id: dataset id | |||
| :param doc_form: doc_form | |||
| :param file_ids: file ids | |||
| Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids) | |||
| """ | |||
| logging.info(click.style("Start batch clean documents when documents deleted", fg="green")) | |||
| start_at = time.perf_counter() | |||
| try: | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| raise Exception("Document has no dataset") | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all() | |||
| # check segment is exist | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| image_upload_file_ids = get_image_upload_file_ids(segment.content) | |||
| for upload_file_id in image_upload_file_ids: | |||
| image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() | |||
| if image_file is None: | |||
| continue | |||
| try: | |||
| storage.delete(image_file.key) | |||
| except Exception: | |||
| logging.exception( | |||
| "Delete image_files failed when storage deleted, \ | |||
| image_upload_file_is: {}".format(upload_file_id) | |||
| ) | |||
| db.session.delete(image_file) | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| if file_ids: | |||
| files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all() | |||
| for file in files: | |||
| try: | |||
| storage.delete(file.key) | |||
| except Exception: | |||
| logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id)) | |||
| db.session.delete(file) | |||
| db.session.commit() | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style( | |||
| "Cleaned documents when documents deleted latency: {}".format(end_at - start_at), | |||
| fg="green", | |||
| ) | |||
| ) | |||
| except Exception: | |||
| logging.exception("Cleaned documents when documents deleted failed") | |||
| @@ -7,13 +7,13 @@ import click | |||
| from celery import shared_task # type: ignore | |||
| from sqlalchemy import func | |||
| from core.indexing_runner import IndexingRunner | |||
| from core.model_manager import ModelManager | |||
| from core.model_runtime.entities.model_entities import ModelType | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from libs import helper | |||
| from models.dataset import Dataset, Document, DocumentSegment | |||
| from services.vector_service import VectorService | |||
| @shared_task(queue="dataset") | |||
| @@ -96,8 +96,7 @@ def batch_create_segment_to_index_task( | |||
| dataset_document.word_count += word_count_change | |||
| db.session.add(dataset_document) | |||
| # add index to db | |||
| indexing_runner = IndexingRunner() | |||
| indexing_runner.batch_add_segments(document_segments, dataset) | |||
| VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) | |||
| db.session.commit() | |||
| redis_client.setex(indexing_cache_key, 600, "completed") | |||
| end_at = time.perf_counter() | |||
| @@ -62,7 +62,7 @@ def clean_dataset_task( | |||
| if doc_form is None: | |||
| raise ValueError("Index type must be specified.") | |||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | |||
| index_processor.clean(dataset, None) | |||
| index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True) | |||
| for document in documents: | |||
| db.session.delete(document) | |||
| @@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| index_processor = IndexProcessorFactory(doc_form).init_index_processor() | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| image_upload_file_ids = get_image_upload_file_ids(segment.content) | |||
| @@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): | |||
| segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| @@ -4,8 +4,9 @@ import time | |||
| import click | |||
| from celery import shared_task # type: ignore | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import Document | |||
| from core.rag.models.document import ChildDocument, Document | |||
| from extensions.ext_database import db | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| @@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): | |||
| db.session.commit() | |||
| # clean index | |||
| index_processor.clean(dataset, None, with_keywords=False) | |||
| index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False) | |||
| for dataset_document in dataset_documents: | |||
| # update from vector index | |||
| @@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunks = segment.child_chunks | |||
| if child_chunks: | |||
| child_documents = [] | |||
| for child_chunk in child_chunks: | |||
| child_document = ChildDocument( | |||
| page_content=child_chunk.content, | |||
| metadata={ | |||
| "doc_id": child_chunk.index_node_id, | |||
| "doc_hash": child_chunk.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| child_documents.append(child_document) | |||
| document.children = child_documents | |||
| documents.append(document) | |||
| # save vector index | |||
| index_processor.load(dataset, documents, with_keywords=False) | |||
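Editor's note: the block that expands a segment's child_chunks into ChildDocument objects now appears in several tasks (add_document_to_index_task, deal_dataset_vector_index_task, enable_segments_to_index_task). A shared helper could cut the duplication; the sketch below is only a suggestion, not part of the diff.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import ChildDocument, Document


def build_index_documents(segments, doc_form):
    """Convert DocumentSegment rows (and, for parent-child indexes, their child
    chunks) into the Document payload expected by index_processor.load()."""
    documents = []
    for segment in segments:
        document = Document(
            page_content=segment.content,
            metadata={
                "doc_id": segment.index_node_id,
                "doc_hash": segment.index_node_hash,
                "document_id": segment.document_id,
                "dataset_id": segment.dataset_id,
            },
        )
        if doc_form == IndexType.PARENT_CHILD_INDEX and segment.child_chunks:
            document.children = [
                ChildDocument(
                    page_content=child_chunk.content,
                    metadata={
                        "doc_id": child_chunk.index_node_id,
                        "doc_hash": child_chunk.index_node_hash,
                        "document_id": segment.document_id,
                        "dataset_id": segment.dataset_id,
                    },
                )
                for child_chunk in segment.child_chunks
            ]
        documents.append(document)
    return documents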
| @@ -6,48 +6,38 @@ from celery import shared_task # type: ignore | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, Document | |||
| @shared_task(queue="dataset") | |||
| def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): | |||
| def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str): | |||
| """ | |||
| Async Remove segment from index | |||
| :param segment_id: | |||
| :param index_node_id: | |||
| :param index_node_ids: | |||
| :param dataset_id: | |||
| :param document_id: | |||
| Usage: delete_segment_from_index_task.delay(segment_id) | |||
| Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id) | |||
| """ | |||
| logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) | |||
| logging.info(click.style("Start delete segment from index", fg="green")) | |||
| start_at = time.perf_counter() | |||
| indexing_cache_key = "segment_{}_delete_indexing".format(segment_id) | |||
| try: | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan")) | |||
| return | |||
| dataset_document = db.session.query(Document).filter(Document.id == document_id).first() | |||
| if not dataset_document: | |||
| logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan")) | |||
| return | |||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | |||
| logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan")) | |||
| return | |||
| index_type = dataset_document.doc_form | |||
| index_processor = IndexProcessorFactory(index_type).init_index_processor() | |||
| index_processor.clean(dataset, [index_node_id]) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| end_at = time.perf_counter() | |||
| logging.info( | |||
| click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green") | |||
| ) | |||
| logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green")) | |||
| except Exception: | |||
| logging.exception("delete segment from index failed") | |||
| finally: | |||
| redis_client.delete(indexing_cache_key) | |||
| @@ -0,0 +1,76 @@ | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| @shared_task(queue="dataset") | |||
| def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str): | |||
| """ | |||
| Async disable segments from index | |||
| :param segment_ids: | |||
| :param dataset_id: | |||
| :param document_id: | |||
| Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id) | |||
| """ | |||
| start_at = time.perf_counter() | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) | |||
| return | |||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||
| if not dataset_document: | |||
| logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) | |||
| return | |||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | |||
| logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) | |||
| return | |||
| # sync index processor | |||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.document_id == document_id, | |||
| ) | |||
| .all() | |||
| ) | |||
| if not segments: | |||
| return | |||
| try: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green")) | |||
| except Exception: | |||
| # update segment error msg | |||
| db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.document_id == document_id, | |||
| ).update( | |||
| { | |||
| "disabled_at": None, | |||
| "disabled_by": None, | |||
| "enabled": True, | |||
| } | |||
| ) | |||
| db.session.commit() | |||
| finally: | |||
| for segment in segments: | |||
| indexing_cache_key = "segment_{}_indexing".format(segment.id) | |||
| redis_client.delete(indexing_cache_key) | |||
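Editor's note: a hedged sketch of the expected caller, inferred from the finally block above, which clears a per-segment segment_{id}_indexing Redis key; the caller would typically set that key before queuing the task. The module path, TTL, and key value are assumptions.

from extensions.ext_redis import redis_client
from tasks.disable_segments_from_index_task import disable_segments_from_index_task  # assumed path

for segment_id in segment_ids:
    # Mark each segment as "indexing in progress"; the task clears these keys when done.
    redis_client.setex("segment_{}_indexing".format(segment_id), 600, "waiting")
disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)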
| @@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| @@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| @@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): | |||
| if document: | |||
| document.indexing_status = "error" | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| return | |||
| @@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| document.indexing_status = "parsing" | |||
| document.processing_started_at = datetime.datetime.utcnow() | |||
| document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| documents.append(document) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -6,8 +6,9 @@ import click | |||
| from celery import shared_task # type: ignore | |||
| from werkzeug.exceptions import NotFound | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import Document | |||
| from core.rag.models.document import ChildDocument, Document | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import DocumentSegment | |||
| @@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str): | |||
| return | |||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunks = segment.child_chunks | |||
| if child_chunks: | |||
| child_documents = [] | |||
| for child_chunk in child_chunks: | |||
| child_document = ChildDocument( | |||
| page_content=child_chunk.content, | |||
| metadata={ | |||
| "doc_id": child_chunk.index_node_id, | |||
| "doc_hash": child_chunk.index_node_hash, | |||
| "document_id": segment.document_id, | |||
| "dataset_id": segment.dataset_id, | |||
| }, | |||
| ) | |||
| child_documents.append(child_document) | |||
| document.children = child_documents | |||
| # save vector index | |||
| index_processor.load(dataset, [document]) | |||
| @@ -0,0 +1,108 @@ | |||
| import datetime | |||
| import logging | |||
| import time | |||
| import click | |||
| from celery import shared_task | |||
| from core.rag.index_processor.constant.index_type import IndexType | |||
| from core.rag.index_processor.index_processor_factory import IndexProcessorFactory | |||
| from core.rag.models.document import ChildDocument, Document | |||
| from extensions.ext_database import db | |||
| from extensions.ext_redis import redis_client | |||
| from models.dataset import Dataset, DocumentSegment | |||
| from models.dataset import Document as DatasetDocument | |||
| @shared_task(queue="dataset") | |||
| def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str): | |||
| """ | |||
| Async enable segments to index | |||
| :param segment_ids: | |||
| :param dataset_id: | |||
| :param document_id: | |||
| Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id) | |||
| """ | |||
| start_at = time.perf_counter() | |||
| dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() | |||
| if not dataset: | |||
| logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan")) | |||
| return | |||
| dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() | |||
| if not dataset_document: | |||
| logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan")) | |||
| return | |||
| if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": | |||
| logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan")) | |||
| return | |||
| # sync index processor | |||
| index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor() | |||
| segments = ( | |||
| db.session.query(DocumentSegment) | |||
| .filter( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.document_id == document_id, | |||
| ) | |||
| .all() | |||
| ) | |||
| if not segments: | |||
| return | |||
| try: | |||
| documents = [] | |||
| for segment in segments: | |||
| document = Document( | |||
| page_content=segment.content, | |||
| metadata={ | |||
| "doc_id": segment.index_node_id, | |||
| "doc_hash": segment.index_node_hash, | |||
| "document_id": document_id, | |||
| "dataset_id": dataset_id, | |||
| }, | |||
| ) | |||
| if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: | |||
| child_chunks = segment.child_chunks | |||
| if child_chunks: | |||
| child_documents = [] | |||
| for child_chunk in child_chunks: | |||
| child_document = ChildDocument( | |||
| page_content=child_chunk.content, | |||
| metadata={ | |||
| "doc_id": child_chunk.index_node_id, | |||
| "doc_hash": child_chunk.index_node_hash, | |||
| "document_id": document_id, | |||
| "dataset_id": dataset_id, | |||
| }, | |||
| ) | |||
| child_documents.append(child_document) | |||
| document.children = child_documents | |||
| documents.append(document) | |||
| # save vector index | |||
| index_processor.load(dataset, documents) | |||
| end_at = time.perf_counter() | |||
| logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green")) | |||
| except Exception as e: | |||
| logging.exception("enable segments to index failed") | |||
| # update segment error msg | |||
| db.session.query(DocumentSegment).filter( | |||
| DocumentSegment.id.in_(segment_ids), | |||
| DocumentSegment.dataset_id == dataset_id, | |||
| DocumentSegment.document_id == document_id, | |||
| ).update( | |||
| { | |||
| "error": str(e), | |||
| "status": "error", | |||
| "disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), | |||
| "enabled": False, | |||
| } | |||
| ) | |||
| db.session.commit() | |||
| finally: | |||
| for segment in segments: | |||
| indexing_cache_key = "segment_{}_indexing".format(segment.id) | |||
| redis_client.delete(indexing_cache_key) | |||
| @@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str): | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| if index_node_ids: | |||
| try: | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False) | |||
| except Exception: | |||
| logging.exception(f"clean dataset {dataset.id} from index failed") | |||
| @@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): | |||
| if document: | |||
| document.indexing_status = "error" | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| redis_client.delete(retry_indexing_cache_key) | |||
| @@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| document.indexing_status = "parsing" | |||
| document.processing_started_at = datetime.datetime.utcnow() | |||
| document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): | |||
| except Exception as ex: | |||
| document.indexing_status = "error" | |||
| document.error = str(ex) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| @@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): | |||
| if document: | |||
| document.indexing_status = "error" | |||
| document.error = str(e) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| redis_client.delete(sync_indexing_cache_key) | |||
| @@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): | |||
| if segments: | |||
| index_node_ids = [segment.index_node_id for segment in segments] | |||
| # delete from vector index | |||
| index_processor.clean(dataset, index_node_ids) | |||
| index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| for segment in segments: | |||
| db.session.delete(segment) | |||
| db.session.commit() | |||
| document.indexing_status = "parsing" | |||
| document.processing_started_at = datetime.datetime.utcnow() | |||
| document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| @@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): | |||
| except Exception as ex: | |||
| document.indexing_status = "error" | |||
| document.error = str(ex) | |||
| document.stopped_at = datetime.datetime.utcnow() | |||
| document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | |||
| db.session.add(document) | |||
| db.session.commit() | |||
| logging.info(click.style(str(ex), fg="yellow")) | |||
| @@ -0,0 +1,98 @@ | |||
| <!DOCTYPE html> | |||
| <html lang="en"> | |||
| <head> | |||
| <meta charset="UTF-8"> | |||
| <meta name="viewport" content="width=device-width, initial-scale=1.0"> | |||
| <title>Documents Disabled Notification</title> | |||
| <style> | |||
| body { | |||
| font-family: Arial, sans-serif; | |||
| margin: 0; | |||
| padding: 0; | |||
| background-color: #f5f5f5; | |||
| } | |||
| .email-container { | |||
| max-width: 600px; | |||
| margin: 20px auto; | |||
| background: #ffffff; | |||
| border-radius: 10px; | |||
| box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1); | |||
| overflow: hidden; | |||
| } | |||
| .header { | |||
| background-color: #eef2fa; | |||
| padding: 20px; | |||
| text-align: center; | |||
| } | |||
| .header img { | |||
| height: 40px; | |||
| } | |||
| .content { | |||
| padding: 20px; | |||
| line-height: 1.6; | |||
| color: #333; | |||
| } | |||
| .content h1 { | |||
| font-size: 24px; | |||
| color: #222; | |||
| } | |||
| .content p { | |||
| margin: 10px 0; | |||
| } | |||
| .content ul { | |||
| padding-left: 20px; | |||
| } | |||
| .content ul li { | |||
| margin-bottom: 10px; | |||
| } | |||
| .cta-button { | |||
| display: block; | |||
| margin: 20px auto; | |||
| padding: 10px 20px; | |||
| background-color: #4e89f9; | |||
| color: #ffffff; | |||
| text-align: center; | |||
| text-decoration: none; | |||
| border-radius: 5px; | |||
| width: fit-content; | |||
| } | |||
| .footer { | |||
| text-align: center; | |||
| padding: 10px; | |||
| font-size: 12px; | |||
| color: #777; | |||
| background-color: #f9f9f9; | |||
| } | |||
| </style> | |||
| </head> | |||
| <body> | |||
| <div class="email-container"> | |||
| <!-- Header --> | |||
| <div class="header"> | |||
| <img src="https://via.placeholder.com/150x40?text=Dify" alt="Dify Logo"> | |||
| </div> | |||
| <!-- Content --> | |||
| <div class="content"> | |||
| <h1>Some Documents in Your Knowledge Base Have Been Disabled</h1> | |||
| <p>Dear {{userName}},</p> | |||
| <p> | |||
| We're sorry for the inconvenience. To ensure optimal performance, documents | |||
| that haven’t been updated or accessed in the past 7 days have been disabled in | |||
| your knowledge bases: | |||
| </p> | |||
| <ul> | |||
| {{knowledge_details}} | |||
| </ul> | |||
| <p>You can re-enable them anytime.</p> | |||
| <a href={{url}} class="cta-button">Re-enable in Dify</a> | |||
| </div> | |||
| <!-- Footer --> | |||
| <div class="footer"> | |||
| Sincerely,<br> | |||
| The Dify Team | |||
| </div> | |||
| </div> | |||
| </body> | |||
| </html> | |||
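Editor's note: a hedged sketch of rendering this template and sending it with the mail extension used earlier in this change set. The template filename, the caller's variables, and the link are assumptions.

from flask import render_template

# All names below are illustrative: account, details, and the URL come from the (hypothetical) caller.
html_content = render_template(
    "clean_document_job_mail_template-US.html",  # assumed filename
    userName=account.name,
    knowledge_details="".join(
        "<li>{}: {} documents disabled</li>".format(dataset_name, doc_count)
        for dataset_name, doc_count in details
    ),
    url="https://cloud.dify.ai/datasets",  # assumed re-enable link
)
mail.send(
    to=account.email,
    subject="Some documents in your knowledge base have been disabled",
    html=html_content,
)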