
Feat/support parent child chunk (#12092)

tags/0.15.0
Jyong, 10 months ago
parent
commit
9231fdbf4c
54 changed files with 2576 additions and 806 deletions
  1. api/controllers/console/datasets/data_source.py (+1 -1)
  2. api/controllers/console/datasets/datasets.py (+14 -1)
  3. api/controllers/console/datasets/datasets_document.py (+99 -79)
  4. api/controllers/console/datasets/datasets_segments.py (+317 -82)
  5. api/controllers/console/datasets/error.py (+12 -0)
  6. api/controllers/service_api/dataset/segment.py (+2 -1)
  7. api/core/entities/knowledge_entities.py (+19 -0)
  8. api/core/indexing_runner.py (+68 -212)
  9. api/core/rag/datasource/retrieval_service.py (+89 -1)
  10. api/core/rag/docstore/dataset_docstore.py (+43 -2)
  11. api/core/rag/embedding/retrieval.py (+23 -0)
  12. api/core/rag/extractor/extract_processor.py (+1 -6)
  13. api/core/rag/extractor/word_extractor.py (+3 -1)
  14. api/core/rag/index_processor/constant/index_type.py (+2 -3)
  15. api/core/rag/index_processor/index_processor_base.py (+14 -11)
  16. api/core/rag/index_processor/index_processor_factory.py (+5 -2)
  17. api/core/rag/index_processor/processor/paragraph_index_processor.py (+23 -6)
  18. api/core/rag/index_processor/processor/parent_child_index_processor.py (+189 -0)
  19. api/core/rag/index_processor/processor/qa_index_processor.py (+41 -22)
  20. api/core/rag/models/document.py (+15 -0)
  21. api/core/rag/retrieval/dataset_retrieval.py (+9 -23)
  22. api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py (+8 -26)
  23. api/fields/dataset_fields.py (+1 -0)
  24. api/fields/document_fields.py (+1 -0)
  25. api/fields/hit_testing_fields.py (+8 -0)
  26. api/fields/segment_fields.py (+14 -0)
  27. api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py (+55 -0)
  28. api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py (+47 -0)
  29. api/models/dataset.py (+84 -3)
  30. api/schedule/clean_unused_datasets_task.py (+35 -1)
  31. api/schedule/mail_clean_document_notify_task.py (+66 -0)
  32. api/services/dataset_service.py (+522 -218)
  33. api/services/entities/knowledge_entities/knowledge_entities.py (+111 -1)
  34. api/services/errors/chunk.py (+9 -0)
  35. api/services/hit_testing_service.py (+5 -32)
  36. api/services/vector_service.py (+171 -23)
  37. api/tasks/add_document_to_index_task.py (+25 -3)
  38. api/tasks/batch_clean_document_task.py (+75 -0)
  39. api/tasks/batch_create_segment_to_index_task.py (+2 -3)
  40. api/tasks/clean_dataset_task.py (+1 -1)
  41. api/tasks/clean_document_task.py (+1 -1)
  42. api/tasks/clean_notion_document_task.py (+1 -1)
  43. api/tasks/deal_dataset_vector_index_task.py (+19 -3)
  44. api/tasks/delete_segment_from_index_task.py (+6 -16)
  45. api/tasks/disable_segments_from_index_task.py (+76 -0)
  46. api/tasks/document_indexing_sync_task.py (+1 -1)
  47. api/tasks/document_indexing_update_task.py (+1 -1)
  48. api/tasks/duplicate_document_indexing_task.py (+3 -3)
  49. api/tasks/enable_segment_to_index_task.py (+18 -1)
  50. api/tasks/enable_segments_to_index_task.py (+108 -0)
  51. api/tasks/remove_document_from_index_task.py (+1 -1)
  52. api/tasks/retry_document_indexing_task.py (+7 -7)
  53. api/tasks/sync_website_document_indexing_task.py (+7 -7)
  54. api/templates/clean_document_job_mail_template-US.html (+98 -0)

api/controllers/console/datasets/data_source.py (+1 -1)

@@ -218,7 +218,7 @@ class DataSourceNotionApi(Resource):
args["doc_form"],
args["doc_language"],
)
return response, 200
return response.model_dump(), 200


class DataSourceNotionDatasetSyncApi(Resource):

api/controllers/console/datasets/datasets.py (+14 -1)

@@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
except Exception as e:
raise IndexingEstimateError(str(e))

return response, 200
return response.model_dump(), 200


class DatasetRelatedAppListApi(Resource):
@@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
}, 200


class DatasetAutoDisableLogApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200


api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")
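
A minimal client-side sketch of the auto-disable-log endpoint registered above; only the route path comes from this diff, while the base URL, token, and dataset id are placeholders.

import requests

CONSOLE_API = "https://cloud.example.com/console/api"                    # placeholder base URL
HEADERS = {"Authorization": "Bearer replace-with-a-real-console-token"}  # placeholder auth
dataset_id = "11111111-1111-1111-1111-111111111111"                      # placeholder id

# GET /datasets/<uuid:dataset_id>/auto-disable-logs lists the dataset's auto-disable log entries.
resp = requests.get(f"{CONSOLE_API}/datasets/{dataset_id}/auto-disable-logs", headers=HEADERS, timeout=10)
resp.raise_for_status()
print(resp.json())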

api/controllers/console/datasets/datasets_document.py (+99 -79)

@@ -52,6 +52,7 @@ from fields.document_fields import (
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task

@@ -255,20 +256,22 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")

parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)

if not dataset.indexing_technique and not args["indexing_technique"]:
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")

# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)

try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -278,6 +281,25 @@ class DatasetDocumentListApi(Resource):

return {"documents": documents, "batch": batch}

@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)

try:
document_ids = request.args.getlist("document_id")
DocumentService.delete_documents(dataset, document_ids)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")

return {"result": "success"}, 204


class DatasetInitApi(Resource):
@setup_required
@@ -313,9 +335,9 @@ class DatasetInitApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
if args["indexing_technique"] == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None:
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try:
model_manager = ModelManager()
@@ -334,11 +356,11 @@ class DatasetInitApi(Resource):
raise ProviderNotInitializeError(ex.description)

# validate args
DocumentService.document_create_args_validate(args)
DocumentService.document_create_args_validate(knowledge_config)

try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, document_data=args, account=current_user
tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -409,7 +431,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
except Exception as e:
raise IndexingEstimateError(str(e))

return response
return response.model_dump(), 200


class DocumentBatchIndexingEstimateApi(DocumentResource):
@@ -422,7 +444,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch)
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
if not documents:
return response
return response, 200
data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict()
info_list = []
@@ -509,7 +531,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise ProviderNotInitializeError(ex.description)
except Exception as e:
raise IndexingEstimateError(str(e))
return response
return response.model_dump(), 200


class DocumentBatchIndexingStatusApi(DocumentResource):
@@ -582,7 +604,8 @@ class DocumentDetailApi(DocumentResource):
if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
elif metadata == "without":
process_rules = DatasetService.get_process_rules(dataset_id)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@@ -590,7 +613,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
@@ -613,7 +637,8 @@ class DocumentDetailApi(DocumentResource):
"doc_language": document.doc_language,
}
else:
process_rules = DatasetService.get_process_rules(dataset_id)
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@@ -621,7 +646,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type,
"data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules,
"dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name,
"created_from": document.created_from,
"created_by": document.created_by,
@@ -757,9 +783,8 @@ class DocumentStatusApi(DocumentResource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, action):
def patch(self, dataset_id, action):
dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
@@ -774,84 +799,79 @@ class DocumentStatusApi(DocumentResource):
# check user's permission
DatasetService.check_dataset_permission(dataset, current_user)

document = self.get_document(dataset_id, document_id)
document_ids = request.args.getlist("document_id")
for document_id in document_ids:
document = self.get_document(dataset_id, document_id)

indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")
indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")

if action == "enable":
if document.enabled:
raise InvalidActionError("Document already enabled.")
if action == "enable":
if document.enabled:
continue
document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)

# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)

add_document_to_index_task.delay(document_id)
elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError(f"Document: {document.name} is not completed.")
if not document.enabled:
continue

return {"result": "success"}, 200
document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError("Document is not completed.")
if not document.enabled:
raise InvalidActionError("Document already disabled.")
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)

document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
remove_document_from_index_task.delay(document_id)

# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
elif action == "archive":
if document.archived:
continue

remove_document_from_index_task.delay(document_id)
document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

return {"result": "success"}, 200
if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)

elif action == "archive":
if document.archived:
raise InvalidActionError("Document already archived.")
remove_document_from_index_task.delay(document_id)

document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
elif action == "un_archive":
if not document.archived:
continue
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)

remove_document_from_index_task.delay(document_id)

return {"result": "success"}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError("Document is not archived.")

document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()

# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)

add_document_to_index_task.delay(document_id)

return {"result": "success"}, 200
else:
raise InvalidActionError()
else:
raise InvalidActionError()
return {"result": "success"}, 200


class DocumentPauseApi(DocumentResource):
@@ -1022,7 +1042,7 @@ api.add_resource(
)
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")
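
The status route now operates on batches: the document id moves out of the URL path into repeated document_id query parameters, which DocumentStatusApi.patch reads via request.args.getlist("document_id"). A hedged client sketch against the new path (base URL, token, and ids are placeholders):

import requests

CONSOLE_API = "https://cloud.example.com/console/api"                    # placeholder base URL
HEADERS = {"Authorization": "Bearer replace-with-a-real-console-token"}  # placeholder auth
dataset_id = "11111111-1111-1111-1111-111111111111"                      # placeholder id

# Disable two documents in one call; each id is passed as a repeated document_id query arg.
resp = requests.patch(
    f"{CONSOLE_API}/datasets/{dataset_id}/documents/status/disable/batch",
    params=[("document_id", "22222222-2222-2222-2222-222222222222"),
            ("document_id", "33333333-3333-3333-3333-333333333333")],
    headers=HEADERS,
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # {"result": "success"} on success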

api/controllers/console/datasets/datasets_segments.py (+317 -82)

@@ -1,5 +1,4 @@
import uuid
from datetime import UTC, datetime

import pandas as pd
from flask import request
@@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError
from controllers.console.datasets.error import (
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
InvalidActionError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_knowledge_limit_check,
@@ -20,15 +25,15 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required
from models import DocumentSegment
from models.dataset import ChildChunk, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task


class DatasetDocumentSegmentListApi(Resource):
@@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound("Document not found.")

parser = reqparse.RequestParser()
parser.add_argument("last_id", type=str, default=None, location="args")
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")

args = parser.parse_args()

last_id = args["last_id"]
page = args["page"]
limit = min(args["limit"], 100)
status_list = args["status"]
hit_count_gte = args["hit_count_gte"]
@@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):

query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
)

if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(DocumentSegment.position > last_segment.position)
else:
return {"data": [], "has_more": False, "limit": limit}, 200
).order_by(DocumentSegment.position.asc())

if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
@@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False)

total = query.count()
segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()

has_more = False
if len(segments) > limit:
has_more = True
segments = segments[:-1]
segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)

return {
"data": marshal(segments, segment_fields),
"doc_form": document.doc_form,
"has_more": has_more,
response = {
"data": marshal(segments.items, segment_fields),
"limit": limit,
"total": total,
}, 200
"total": segments.total,
"total_pages": segments.pages,
"page": page,
}
return response, 200

@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")

# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 200


class DatasetDocumentSegmentApi(Resource):
@@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, segment_id, action):
def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor
@@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource):
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
segment_ids = request.args.getlist("segment_id")

segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()

if not segment:
raise NotFound("Segment not found.")

if segment.status != "completed":
raise NotFound("Segment is not completed, enable or disable function is not allowed")

document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
document_indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later")

indexing_cache_key = "segment_{}_indexing".format(segment.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError("Segment is being indexed, please try again later")

if action == "enable":
if segment.enabled:
raise InvalidActionError("Segment is already enabled.")

segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.commit()

# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)

enable_segment_to_index_task.delay(segment.id)

return {"result": "success"}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")

segment.enabled = False
segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.commit()

# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)

disable_segment_from_index_task.delay(segment.id)

return {"result": "success"}, 200
else:
raise InvalidActionError()
try:
SegmentService.update_segments_status(segment_ids, action, dataset, document)
except Exception as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200


class DatasetDocumentSegmentAddApi(Resource):
@@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

@setup_required
@@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200


class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200

@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")

args = parser.parse_args()

page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]

child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200

@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200


class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200

@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200


api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource(
DatasetDocumentSegmentUpdateApi,
@@ -424,3 +651,11 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>",
)
api.add_resource(
ChildChunkAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
ChildChunkUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)
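
The two resources above add a CRUD surface for child chunks under a parent segment. A hedged client sketch of those routes (base URL, token, and ids are placeholders; the id lookup assumes the marshalled child chunk exposes an id field):

import requests

CONSOLE_API = "https://cloud.example.com/console/api"                    # placeholder base URL
HEADERS = {"Authorization": "Bearer replace-with-a-real-console-token"}  # placeholder auth
dataset_id = "11111111-1111-1111-1111-111111111111"                      # placeholder ids
document_id = "22222222-2222-2222-2222-222222222222"
segment_id = "33333333-3333-3333-3333-333333333333"
base = f"{CONSOLE_API}/datasets/{dataset_id}/documents/{document_id}/segments/{segment_id}"

# Create a child chunk under the parent segment.
created = requests.post(f"{base}/child_chunks", json={"content": "A small child chunk."}, headers=HEADERS)

# List child chunks for the segment, paginated.
listed = requests.get(f"{base}/child_chunks", params={"page": 1, "limit": 20}, headers=HEADERS)

# Update and then delete a single child chunk (assuming the marshalled payload exposes its id).
chunk_id = created.json()["data"]["id"]
requests.patch(f"{base}/child_chunks/{chunk_id}", json={"content": "Edited child chunk."}, headers=HEADERS)
requests.delete(f"{base}/child_chunks/{chunk_id}", headers=HEADERS)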

api/controllers/console/datasets/error.py (+12 -0)

@@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException):
error_code = "indexing_estimate_error"
description = "Knowledge indexing estimate failed: {message}"
code = 500


class ChildChunkIndexingError(BaseHTTPException):
error_code = "child_chunk_indexing_error"
description = "Create child chunk index failed: {message}"
code = 500


class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500

api/controllers/service_api/dataset/segment.py (+2 -1)

@@ -16,6 +16,7 @@ from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs


class SegmentApi(DatasetApiResource):
@@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = parser.parse_args()

SegmentService.segment_create_args_validate(args["segment"], document)
segment = SegmentService.update_segment(args["segment"], segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200



api/core/entities/knowledge_entities.py (+19 -0)

@@ -0,0 +1,19 @@
from typing import Optional

from pydantic import BaseModel


class PreviewDetail(BaseModel):
content: str
child_chunks: Optional[list[str]] = None


class QAPreviewDetail(BaseModel):
question: str
answer: str


class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
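
A small sketch of how these entities fit together for a parent-child indexing-estimate preview; the values are illustrative, and model_dump() matches how the controllers above serialize the response:

from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail

# One parent chunk with its child chunks, as surfaced by the indexing estimate preview.
estimate = IndexingEstimate(
    total_segments=1,
    preview=[
        PreviewDetail(
            content="Parent chunk text...",
            child_chunks=["child chunk 1", "child chunk 2"],
        )
    ],
)
print(estimate.model_dump())  # the controllers return response.model_dump(), 200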

api/core/indexing_runner.py (+68 -212)

@@ -8,34 +8,34 @@ import time
import uuid
from typing import Any, Optional, cast

from flask import Flask, current_app
from flask import current_app
from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
@@ -115,6 +115,9 @@ class IndexingRunner:

for document_segment in document_segments:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
@@ -183,7 +186,22 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id,
},
)

if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)

# build index
@@ -222,7 +240,7 @@ class IndexingRunner:
doc_language: str = "English",
dataset_id: Optional[str] = None,
indexing_technique: str = "economy",
) -> dict:
) -> IndexingEstimate:
"""
Estimate the indexing for the document.
"""
@@ -258,31 +276,38 @@ class IndexingRunner:
tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
preview_texts: list[str] = []
preview_texts = []

total_segments = 0
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings:
# extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)

# get splitter
splitter = self._get_splitter(processing_rule, embedding_model_instance)

# split to documents
documents = self._split_to_documents_for_estimate(
text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
documents = index_processor.transform(
text_docs,
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule.to_dict(),
tenant_id=current_user.current_tenant_id,
doc_language=doc_language,
preview=True,
)

total_segments += len(documents)
for document in documents:
if len(preview_texts) < 5:
preview_texts.append(document.page_content)
if len(preview_texts) < 10:
if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer")
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content)
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children]
preview_texts.append(preview_detail)

# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@@ -299,15 +324,8 @@ class IndexingRunner:
db.session.delete(image_file)

if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(
current_user.current_tenant_id, preview_texts[0], doc_language
)
document_qa_list = self.format_split_text(response)

return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
return {"total_segments": total_segments, "preview": preview_texts}
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)

def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@@ -401,31 +419,26 @@ class IndexingRunner:

@staticmethod
def _get_splitter(
processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule.mode == "custom":
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

separator = segmentation["separator"]
if separator:
separator = separator.replace("\\n", "\n")

if segmentation.get("chunk_overlap"):
chunk_overlap = segmentation["chunk_overlap"]
else:
chunk_overlap = 0

character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_size=max_tokens,
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "。", ". ", " ", ""],
@@ -443,142 +456,6 @@ class IndexingRunner:

return character_splitter

def _step_split(
self,
text_docs: list[Document],
splitter: TextSplitter,
dataset: Dataset,
dataset_document: DatasetDocument,
processing_rule: DatasetProcessRule,
) -> list[Document]:
"""
Split the text documents into documents and save them to the document segment.
"""
documents = self._split_to_documents(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language,
)

# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)

# add document segments
doc_store.add_documents(documents)

# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
},
)

# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
},
)

return documents

def _split_to_documents(
self,
text_docs: list[Document],
splitter: TextSplitter,
processing_rule: DatasetProcessRule,
tenant_id: str,
document_form: str,
document_language: str,
) -> list[Document]:
"""
Split the text documents into nodes.
"""
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text

# parse document to nodes
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
if document_node.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)

if document_node.page_content:
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == "qa_model":
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": document_language,
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents
return all_documents

def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception("Failed to format qa document")

all_qa_documents.extend(format_documents)

def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
) -> list[Document]:
@@ -624,11 +501,11 @@ class IndexingRunner:
return document_text

@staticmethod
def format_split_text(text):
def format_split_text(text: str) -> list[QAPreviewDetail]:
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE)

return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]
return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]

def _load(
self,
@@ -654,13 +531,14 @@ class IndexingRunner:
indexing_start_at = time.perf_counter()
tokens = 0
chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
)
create_keyword_thread.start()

# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = []
@@ -680,8 +558,8 @@ class IndexingRunner:

for future in futures:
tokens += future.result()
create_keyword_thread.join()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
create_keyword_thread.join()
indexing_end_at = time.perf_counter()

# update document status to completed
@@ -793,28 +671,6 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit()

@staticmethod
def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)

def _transform(
self,
index_processor: BaseIndexProcessor,
@@ -856,7 +712,7 @@ class IndexingRunner:
)

# add document segments
doc_store.add_documents(documents)
doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)

# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

api/core/rag/datasource/retrieval_service.py (+89 -1)

@@ -6,11 +6,14 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.dataset import Dataset
from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model = {
@@ -248,3 +251,88 @@ class RetrievalService:
@staticmethod
def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"')

@staticmethod
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
records = []
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document and dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata["doc_id"]
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
index_node_id = document.metadata["doc_id"]

segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)

if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}

records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
record["score"] = segment_child_map[record["segment"].id]["max_score"]

return [RetrievalSegments(**record) for record in records]
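
A hedged sketch of the new grouping helper: retrieval Documents whose metadata carries doc_id, document_id, and score are folded into RetrievalSegments, with child-chunk hits grouped under their parent segment and the parent score taken as the best child score. The ids below are placeholders, and the referenced ChildChunk and DocumentSegment rows are assumed to exist in the database:

from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document

# Two child-chunk hits that point at the same parent-child indexed document.
hits = [
    Document(page_content="child A", metadata={"doc_id": "child-node-1", "document_id": "doc-1", "score": 0.91}),
    Document(page_content="child B", metadata={"doc_id": "child-node-2", "document_id": "doc-1", "score": 0.72}),
]

for record in RetrievalService.format_retrieval_documents(hits):
    # For parent-child indexed documents, record.score is the best child score and
    # record.child_chunks lists each matched child with its own score and position.
    print(record.segment.id, record.score, [c.content for c in (record.child_chunks or [])])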

api/core/rag/docstore/dataset_docstore.py (+43 -2)

@@ -7,7 +7,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import ChildChunk, Dataset, DocumentSegment


class DatasetDocumentStore:
@@ -60,7 +60,7 @@ class DatasetDocumentStore:

return output

def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id)
@@ -120,6 +120,23 @@ class DatasetDocumentStore:
segment_document.answer = doc.metadata.pop("answer", "")

db.session.add(segment_document)
db.session.flush()
if save_child:
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else:
segment_document.content = doc.page_content
if doc.metadata.get("answer"):
@@ -127,6 +144,30 @@ class DatasetDocumentStore:
segment_document.index_node_hash = doc.metadata["doc_hash"]
segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
).delete()
# add new child chunks
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata["doc_id"],
index_node_hash=child.metadata["doc_hash"],
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)

db.session.commit()
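
As a rough sketch of the new save_child path (not from the diff): a Document that already carries pre-split children is handed to the store. The doc_store variable stands for a DatasetDocumentStore bound to the target dataset and document; its construction and the id/hash values are placeholders.

from core.rag.models.document import ChildDocument, Document

parent_doc = Document(
    page_content="Parent chunk text ...",
    metadata={"doc_id": "parent-node-id", "doc_hash": "parent-hash"},
    children=[
        ChildDocument(
            page_content="First child chunk ...",
            metadata={"doc_id": "child-node-id-1", "doc_hash": "child-hash-1"},
        ),
    ],
)

# doc_store: a DatasetDocumentStore for the target dataset/document (construction omitted).
doc_store.add_documents([parent_doc], save_child=True)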


+ 23
- 0
api/core/rag/embedding/retrieval.py Show file

@@ -0,0 +1,23 @@
from typing import Optional

from pydantic import BaseModel

from models.dataset import DocumentSegment


class RetrievalChildChunk(BaseModel):
    """Retrieval child chunk."""

id: str
content: str
score: float
position: int


class RetrievalSegments(BaseModel):
"""Retrieval segments."""

model_config = {"arbitrary_types_allowed": True}
segment: DocumentSegment
child_chunks: Optional[list[RetrievalChildChunk]] = None
score: Optional[float] = None

+ 1
- 6
api/core/rag/extractor/extract_processor.py Show file

@@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
@@ -141,11 +140,7 @@ class ExtractProcessor:
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
else:
# txt
extractor = (
UnstructuredTextExtractor(file_path, unstructured_api_url)
if is_automatic
else TextExtractor(file_path, autodetect_encoding=True)
)
extractor = TextExtractor(file_path, autodetect_encoding=True)
else:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)

+ 3
- 1
api/core/rag/extractor/word_extractor.py Show file

@@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para)
if parsed_paragraph:
if parsed_paragraph.strip():
content.append(parsed_paragraph)
else:
content.append("\n")
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map))

+ 2
- 3
api/core/rag/index_processor/constant/index_type.py Show file

@@ -1,8 +1,7 @@
from enum import Enum


class IndexType(Enum):
class IndexType(str, Enum):
PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "parent_child_index"
SUMMARY_INDEX = "summary_index"
PARENT_CHILD_INDEX = "hierarchical_model"
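
A small illustration (not part of the commit) of what the str mixin buys: members now compare equal to the raw doc_form strings stored in the database, which is why the factory below can drop the .value accesses.

from core.rag.index_processor.constant.index_type import IndexType

assert IndexType.PARAGRAPH_INDEX == "text_model"
assert IndexType.PARENT_CHILD_INDEX == "hierarchical_model"

doc_form = "hierarchical_model"  # e.g. read from Document.doc_form
print(doc_form == IndexType.PARENT_CHILD_INDEX)  # True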

+ 14
- 11
api/core/rag/index_processor/index_processor_base.py Show file

@@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError

@abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
raise NotImplementedError

def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError

@abstractmethod
@@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC):
) -> list[Document]:
raise NotImplementedError

def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
def _get_splitter(
self,
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule["mode"] == "custom":
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
rules = processing_rule["rules"]
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

separator = segmentation["separator"]
if separator:
separator = separator.replace("\\n", "\n")

character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=segmentation.get("chunk_overlap", 0) or 0,
chunk_size=max_tokens,
chunk_overlap=chunk_overlap,
fixed_separator=separator,
separators=["\n\n", "。", ". ", " ", ""],
embedding_model_instance=embedding_model_instance,

+ 5
- 2
api/core/rag/index_processor/index_processor_factory.py Show file

@@ -3,6 +3,7 @@
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor


@@ -18,9 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")

if self._index_type == IndexType.PARAGRAPH_INDEX.value:
if self._index_type == IndexType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value:
elif self._index_type == IndexType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")
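
Illustrative use of the factory with the new index type; dataset and documents are placeholders for objects prepared elsewhere (for example by the indexing runner), so this is a sketch rather than a runnable call.

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory

index_processor = IndexProcessorFactory(IndexType.PARENT_CHILD_INDEX).init_index_processor()

# dataset: models.dataset.Dataset; documents: list[Document] whose .children were filled
# by ParentChildIndexProcessor.transform (both placeholders).
index_processor.load(dataset, documents)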

+ 23
- 6
api/core/rag/index_processor/processor/paragraph_index_processor.py Show file

@@ -13,21 +13,34 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from models.dataset import Dataset, DatasetProcessRule
from services.entities.knowledge_entities.knowledge_entities import Rule


class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)

return text_docs

def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
else:
rules = Rule(**process_rule.get("rules"))
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule", {}),
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
all_documents = []
@@ -53,15 +66,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents)
return all_documents

def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
if with_keywords:
keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset)
keyword.create(documents)
if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)

def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:

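For reference, a sketch of the kwargs shape the reworked transform expects; the rule values are illustrative, and with mode "automatic" the DatasetProcessRule.AUTOMATIC_RULES defaults are used instead.

from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.models.document import Document

text_docs = [Document(page_content="Some extracted text ...", metadata={"doc_id": "tmp", "doc_hash": "tmp"})]
process_rule = {
    "mode": "custom",
    "rules": {
        "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
        "segmentation": {"separator": "\n\n", "max_tokens": 512, "chunk_overlap": 50},
    },
}

chunks = ParagraphIndexProcessor().transform(
    text_docs,
    process_rule=process_rule,
    embedding_model_instance=None,
)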
+ 189
- 0
api/core/rag/index_processor/processor/parent_child_index_processor.py Show file

@@ -0,0 +1,189 @@
"""Parent-child index processor."""

import uuid
from typing import Optional

from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule


class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)

return text_docs

def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
rules = Rule(**process_rule.get("rules"))
all_documents = []
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, process_rule)
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith("。"):
page_content = page_content[1:].strip()
else:
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document_node.children = child_nodes
split_documents.append(document_node)
all_documents.extend(split_documents)
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
all_documents.append(document)

return all_documents

def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)

def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
vector = Vector(dataset)
if node_ids:
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
else:
vector.delete()

if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()

def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

def _split_child_nodes(
self,
document_node: Document,
rules: Rule,
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
) -> list[ChildDocument]:
child_splitter = self._get_splitter(
processing_rule_mode=process_rule_mode,
max_tokens=rules.subchunk_segmentation.max_tokens,
chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
separator=rules.subchunk_segmentation.separator,
embedding_model_instance=embedding_model_instance,
)
# parse document to child nodes
child_nodes = []
child_documents = child_splitter.split_documents([document_node])
for child_document_node in child_documents:
if child_document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(child_document_node.page_content)
child_document = ChildDocument(
page_content=child_document_node.page_content, metadata=document_node.metadata
)
child_document.metadata["doc_id"] = doc_id
child_document.metadata["doc_hash"] = hash
child_page_content = child_document.page_content
if child_page_content.startswith(".") or child_page_content.startswith("。"):
child_page_content = child_page_content[1:].strip()
if len(child_page_content) > 0:
child_document.page_content = child_page_content
child_nodes.append(child_document)
return child_nodes
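
An illustrative process_rule for the new processor: segmentation drives the parent splitter, parent_mode selects paragraph vs. full-doc parents, and subchunk_segmentation re-splits each parent into ChildDocument nodes. Values and the sample document are placeholders.

from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.models.document import Document

text_docs = [Document(page_content="Some extracted text ...", metadata={"doc_id": "tmp", "doc_hash": "tmp"})]
process_rule = {
    "mode": "hierarchical",
    "rules": {
        "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}],
        "segmentation": {"separator": "\n\n", "max_tokens": 1024, "chunk_overlap": 0},
        "parent_mode": "paragraph",  # "full-doc" indexes the whole document as a single parent
        "subchunk_segmentation": {"separator": "\n", "max_tokens": 256, "chunk_overlap": 0},
    },
}

parent_docs = ParentChildIndexProcessor().transform(
    text_docs,
    process_rule=process_rule,
    embedding_model_instance=None,
)
for parent in parent_docs:
    print(parent.metadata["doc_id"], len(parent.children or []))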

+ 41
- 22
api/core/rag/index_processor/processor/qa_index_processor.py Show file

@@ -21,18 +21,28 @@ from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule


class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic"
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs

def transform(self, documents: list[Document], **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
rules = Rule(**process_rule.get("rules"))
splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule") or {},
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)

@@ -59,24 +69,33 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node)
all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": kwargs.get("doc_language", "English"),
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
if preview:
self._format_qa_document(
current_app._get_current_object(),
kwargs.get("tenant_id"),
all_documents[0],
all_qa_documents,
kwargs.get("doc_language", "English"),
)
else:
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(
target=self._format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(),
"tenant_id": kwargs.get("tenant_id"),
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": kwargs.get("doc_language", "English"),
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents

def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
@@ -98,12 +117,12 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e))
return text_docs

def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)

def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)

+ 15
- 0
api/core/rag/models/document.py Show file

@@ -5,6 +5,19 @@ from typing import Any, Optional
from pydantic import BaseModel, Field


class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""

page_content: str

vector: Optional[list[float]] = None

"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: Optional[dict] = Field(default_factory=dict)


class Document(BaseModel):
"""Class for storing a piece of text and associated metadata."""

@@ -19,6 +32,8 @@ class Document(BaseModel):

provider: Optional[str] = "dify"

children: Optional[list[ChildDocument]] = None


class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems.

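A self-contained sketch of the new model shape: a Document carrying ChildDocument nodes, and the flattening the parent-child processor performs before writing child vectors.

from core.rag.models.document import ChildDocument, Document

doc = Document(
    page_content="Parent chunk ...",
    metadata={"doc_id": "p-1", "doc_hash": "hash-p1"},
    children=[ChildDocument(page_content="Child chunk ...", metadata={"doc_id": "c-1", "doc_hash": "hash-c1"})],
)

# ParentChildIndexProcessor.load() flattens children back into Documents for the vector store:
child_vector_inputs = [Document(**child.model_dump()) for child in doc.children]
print(child_vector_inputs[0].page_content)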
+ 9
- 23
api/core/rag/retrieval/dataset_retrieval.py Show file

@@ -166,43 +166,29 @@ class DatasetRetrieval:
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list = {}
# deal with dify documents
if dify_documents:
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()

if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=document_score_list.get(segment.index_node_id, None),
score=record.score,
)
)
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=document_score_list.get(segment.index_node_id, None),
score=record.score,
)
)
if show_retrieve_source:
for segment in sorted_segments:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
@@ -218,7 +204,7 @@ class DatasetRetrieval:
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, 0.0),
"score": record.score or 0.0,
}

if invoke_from.to_source() == "dev":

+ 8
- 26
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py Show file

@@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment
@@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus

from .entities import KnowledgeRetrievalNodeData
@@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content,
}
retrieval_resource_list.append(source)
document_score_list: dict[str, float] = {}
# deal with dify documents
if dify_documents:
document_score_list = {}
for item in dify_documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)

for segment in sorted_segments:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter(
Document.id == segment.document_id,
@@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"document_data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": document_score_list.get(segment.index_node_id, None),
"score": record.score or 0.0,
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
@@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True,
)
position = 1
for item in retrieval_resource_list:
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position
position += 1
return retrieval_resource_list

@classmethod

+ 1
- 0
api/fields/dataset_fields.py Show file

@@ -73,6 +73,7 @@ dataset_detail_fields = {
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
}

+ 1
- 0
api/fields/document_fields.py Show file

@@ -34,6 +34,7 @@ document_with_segments_fields = {
"data_source_info": fields.Raw(attribute="data_source_info_dict"),
"data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
"dataset_process_rule_id": fields.String,
"process_rule_dict": fields.Raw(attribute="process_rule_dict"),
"name": fields.String,
"created_from": fields.String,
"created_by": fields.String,

+ 8
- 0
api/fields/hit_testing_fields.py Show file

@@ -34,8 +34,16 @@ segment_fields = {
"document": fields.Nested(document_fields),
}

child_chunk_fields = {
"id": fields.String,
"content": fields.String,
"position": fields.Integer,
"score": fields.Float,
}

hit_testing_record_fields = {
"segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float,
"tsne_position": fields.Raw,
}

+ 14
- 0
api/fields/segment_fields.py Show file

@@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore

from libs.helper import TimestampField

child_chunk_fields = {
"id": fields.String,
"segment_id": fields.String,
"content": fields.String,
"position": fields.Integer,
"word_count": fields.Integer,
"type": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}

segment_fields = {
"id": fields.String,
"position": fields.Integer,
@@ -20,10 +31,13 @@ segment_fields = {
"status": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
"updated_by": fields.String,
"indexing_at": TimestampField,
"completed_at": TimestampField,
"error": fields.String,
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
}

segment_list_response = {

+ 55
- 0
api/migrations/versions/2024_11_22_0701-e19037032219_parent_child_index.py Show file

@@ -0,0 +1,55 @@
"""parent-child-index

Revision ID: e19037032219
Revises: 01d6889832f7
Create Date: 2024-11-22 07:01:17.550037

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'e19037032219'
down_revision = 'd7999dfa4aae'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('child_chunks',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('word_count', sa.Integer(), nullable=False),
sa.Column('index_node_id', sa.String(length=255), nullable=True),
sa.Column('index_node_hash', sa.String(length=255), nullable=True),
sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('indexing_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
)
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.drop_index('child_chunk_dataset_id_idx')

op.drop_table('child_chunks')
# ### end Alembic commands ###

+ 47
- 0
api/migrations/versions/2024_12_25_1137-923752d42eb6_add_auto_disabled_dataset_logs.py Show file

@@ -0,0 +1,47 @@
"""add_auto_disabled_dataset_logs

Revision ID: 923752d42eb6
Revises: e19037032219
Create Date: 2024-12-25 11:37:55.467101

"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '923752d42eb6'
down_revision = 'e19037032219'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_auto_disable_logs',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
)
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False)
batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.drop_index('dataset_auto_disable_log_tenant_idx')
batch_op.drop_index('dataset_auto_disable_log_dataset_idx')
batch_op.drop_index('dataset_auto_disable_log_created_atx')

op.drop_table('dataset_auto_disable_logs')
# ### end Alembic commands ###

+ 84
- 3
api/models/dataset.py Show file

@@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

from .account import Account
from .engine import db
@@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

MODES = ["automatic", "custom"]
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: dict[str, Any] = {
"pre_processing_rules": [
@@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
"dataset_id": self.dataset_id,
"mode": self.mode,
"rules": self.rules_dict,
"created_by": self.created_by,
"created_at": self.created_at,
}

@property
@@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined]
.scalar()
)

@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
return self.dataset_process_rule.to_dict()
return None

def to_dict(self):
return {
"id": self.id,
@@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
.first()
)

@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
return []

def get_sign_content(self):
signed_urls = []
text = self.content
@@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text


class ChildChunk(db.Model):
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
)

# initial fields
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
segment_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False)
word_count = db.Column(db.Integer, nullable=False)
# indexing fields
index_node_id = db.Column(db.String(255), nullable=True)
index_node_hash = db.Column(db.String(255), nullable=True)
type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)

@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()

@property
def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()


class AppDatasetJoin(db.Model): # type: ignore[name-defined]
__tablename__ = "app_dataset_joins"
__table_args__ = (
@@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetAutoDisableLog(db.Model):
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
)

id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
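
A small query sketch against the new table (run inside the Flask app context; the segment id is a placeholder). The DocumentSegment.child_chunks property above applies the same position ordering.

from extensions.ext_database import db
from models.dataset import ChildChunk

child_chunks = (
    db.session.query(ChildChunk)
    .filter(ChildChunk.segment_id == "segment-uuid")  # placeholder id
    .order_by(ChildChunk.position.asc())
    .all()
)
for chunk in child_chunks:
    print(chunk.position, chunk.word_count, chunk.index_node_id)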

+ 35
- 1
api/schedule/clean_unused_datasets_task.py Show file

@@ -10,7 +10,7 @@ from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetQuery, Document
from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
from services.feature_service import FeatureService


@@ -75,6 +75,23 @@ def clean_unused_datasets_task():
)
if not dataset_query or len(dataset_query) == 0:
try:
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)
@@ -151,6 +168,23 @@ def clean_unused_datasets_task():
else:
plan = plan_cache.decode()
if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None)

+ 66
- 0
api/schedule/mail_clean_document_notify_task.py Show file

@@ -0,0 +1,66 @@
import logging
import time
from collections import defaultdict

import click
from celery import shared_task
from flask import render_template

from extensions.ext_mail import mail
from models.account import Account, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetAutoDisableLog


@shared_task(queue="mail")
def send_document_clean_notify_task():
"""
Async Send document clean notify mail

Usage: send_document_clean_notify_task.delay()
"""
if not mail.is_inited():
return

logging.info(click.style("Start send document clean notify mail", fg="green"))
start_at = time.perf_counter()

# send document clean notify mail
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
# group by tenant_id
dataset_auto_disable_logs_map = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
    dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)

for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
knowledge_details = []
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
if not tenant:
continue
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
if not account:
continue

dataset_auto_dataset_map = defaultdict(list)
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
    dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
        dataset_auto_disable_log.document_id
    )

for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(f"<li>Knowledge base {dataset.name}: {document_count} documents</li>")

html_content = render_template(
    "clean_document_job_mail_template-US.html",
)
mail.send(to=account.email, subject="Dify: documents in your knowledge base were disabled", html=html_content)

end_at = time.perf_counter()
logging.info(
click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
)
except Exception:
logging.exception("Send document clean notify mail failed")

+ 522
- 218
api/services/dataset_service.py
File changes are not shown because they are too large.
Show file


+ 111
- 1
api/services/entities/knowledge_entities/knowledge_entities.py Show file

@@ -1,4 +1,5 @@
from typing import Optional
from enum import Enum
from typing import Literal, Optional

from pydantic import BaseModel

@@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel):
answer: Optional[str] = None
keywords: Optional[list[str]] = None
enabled: Optional[bool] = None


class ParentMode(str, Enum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"


class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None


class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
type: str


class NotionInfo(BaseModel):
workspace_id: str
pages: list[NotionPage]


class WebsiteInfo(BaseModel):
provider: str
job_id: str
urls: list[str]
only_main_content: bool = True


class FileInfo(BaseModel):
file_ids: list[str]


class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None


class DataSource(BaseModel):
info_list: InfoList


class PreProcessingRule(BaseModel):
id: str
enabled: bool


class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0


class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None


class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None


class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None


class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None


class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: Optional[DataSource] = None
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None


class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None


class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
content: str
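
A self-contained example of the new entities: a hierarchical ProcessRule with parent and sub-chunk segmentation (values are illustrative).

from services.entities.knowledge_entities.knowledge_entities import ProcessRule, Rule, Segmentation

process_rule = ProcessRule(
    mode="hierarchical",
    rules=Rule(
        segmentation=Segmentation(separator="\n\n", max_tokens=1024, chunk_overlap=0),
        parent_mode="paragraph",
        subchunk_segmentation=Segmentation(separator="\n", max_tokens=256),
    ),
)
print(process_rule.model_dump())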

+ 9
- 0
api/services/errors/chunk.py Show file

@@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError


class ChildChunkIndexingError(BaseServiceError):
description = "{message}"


class ChildChunkDeleteIndexError(BaseServiceError):
description = "{message}"

+ 5
- 32
api/services/hit_testing_service.py Show file

@@ -7,7 +7,7 @@ from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Dataset, DatasetQuery

default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@@ -69,7 +69,7 @@ class HitTestingService:
db.session.add(dataset_query)
db.session.commit()

return dict(cls.compact_retrieve_response(dataset, query, all_documents))
return cls.compact_retrieve_response(query, all_documents)

@classmethod
def external_retrieve(
@@ -106,41 +106,14 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))

@classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]):
records = []

for document in documents:
if document.metadata is None:
continue

index_node_id = document.metadata["doc_id"]

segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)

if not segment:
continue

record = {
"segment": segment,
"score": document.metadata.get("score", None),
}

records.append(record)
def compact_retrieve_response(cls, query: str, documents: list[Document]):
records = RetrievalService.format_retrieval_documents(documents)

return {
"query": {
"content": query,
},
"records": records,
"records": [record.model_dump() for record in records],
}

@classmethod

+ 171
- 23
api/services/vector_service.py Show file

@@ -1,40 +1,68 @@
from typing import Optional

from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from models.dataset import Dataset, DocumentSegment
from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode


class VectorService:
@classmethod
def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
):
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts(documents, duplicate_check=True)

# save keyword index
keyword = Keyword(dataset)
for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX:
document = DatasetDocument.query.filter_by(id=segment.document_id).first()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()

if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)
if dataset.embedding_model_provider:
embedding_model_instance = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
raise ValueError("The knowledge base index technique is not high quality!")
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
else:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)

@classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
@@ -65,3 +93,123 @@ class VectorService:
keyword.add_texts([document], keywords_list=[keywords])
else:
keyword.add_texts([document])

@classmethod
def generate_child_chunks(
cls,
segment: DocumentSegment,
dataset_document: Document,
dataset: Dataset,
embedding_model_instance: ModelInstance,
processing_rule: DatasetProcessRule,
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)

# generate child chunks
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule_dict,
tenant_id=dataset.tenant_id,
doc_language=dataset_document.doc_language,
)
# save child chunks
if len(documents) > 0 and len(documents[0].children) > 0:
index_processor.load(dataset, documents)

for position, child_chunk in enumerate(documents[0].children, start=1):
child_segment = ChildChunk(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=dataset_document.id,
segment_id=segment.id,
position=position,
index_node_id=child_chunk.metadata["doc_id"],
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
db.session.commit()

@classmethod
def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
child_document = Document(
page_content=child_segment.content,
metadata={
"doc_id": child_segment.index_node_id,
"doc_hash": child_segment.index_node_hash,
"document_id": child_segment.document_id,
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)

@classmethod
def update_child_chunk_vector(
cls,
new_child_chunks: list[ChildChunk],
update_child_chunks: list[ChildChunk],
delete_child_chunks: list[ChildChunk],
dataset: Dataset,
):
documents = []
delete_node_ids = []
for new_child_chunk in new_child_chunks:
new_child_document = Document(
page_content=new_child_chunk.content,
metadata={
"doc_id": new_child_chunk.index_node_id,
"doc_hash": new_child_chunk.index_node_hash,
"document_id": new_child_chunk.document_id,
"dataset_id": new_child_chunk.dataset_id,
},
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
child_document = Document(
page_content=update_child_chunk.content,
metadata={
"doc_id": update_child_chunk.index_node_id,
"doc_hash": update_child_chunk.index_node_hash,
"document_id": update_child_chunk.document_id,
"dataset_id": update_child_chunk.dataset_id,
},
)
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
vector.delete_by_ids(delete_node_ids)
if documents:
vector.add_texts(documents, duplicate_check=True)

@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])
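
Illustrative call into the reworked service entry point; segments and dataset stand for freshly committed rows, as in batch_create_segment_to_index_task further below.

from core.rag.index_processor.constant.index_type import IndexType
from services.vector_service import VectorService

# segments: list[DocumentSegment] already committed; dataset: their Dataset (placeholders).
VectorService.create_segments_vector(
    keywords_list=None,
    segments=segments,
    dataset=dataset,
    doc_form=IndexType.PARENT_CHILD_INDEX,
)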

+ 25
- 3
api/tasks/add_document_to_index_task.py Show file

@@ -6,12 +6,13 @@ import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.dataset import DocumentSegment


@shared_task(queue="dataset")
@@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id,
},
)

if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)

dataset = dataset_document.dataset
@@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)

# delete auto disable log
db.session.query(DatasetAutoDisableLog).filter(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
db.session.commit()

end_at = time.perf_counter()
logging.info(
click.style(

+ 75
- 0
api/tasks/batch_clean_document_task.py Show file

@@ -0,0 +1,75 @@
import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment
from models.model import UploadFile


@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
"""
Clean documents when documents are deleted.
:param document_ids: document ids
:param dataset_id: dataset id
:param doc_form: doc_form
:param file_ids: file ids

Usage: batch_clean_document_task.delay(document_ids, dataset_id, doc_form, file_ids)
"""
logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()

try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

if not dataset:
raise Exception("Document has no dataset")

segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_id: {}".format(upload_file_id)
)
db.session.delete(image_file)
db.session.delete(segment)

db.session.commit()
if file_ids:
files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
db.session.delete(file)
db.session.commit()

end_at = time.perf_counter()
logging.info(
click.style(
"Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
fg="green",
)
)
except Exception:
logging.exception("Cleaned documents when documents deleted failed")
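For reference, a minimal dispatch sketch for this new task, assuming it is queued by a service layer after collecting the deleted documents and their upload files; the ids and the doc_form value are placeholders, only the task signature is taken from the file above.

# Hypothetical dispatch example; ids and the doc_form value are placeholders.
from tasks.batch_clean_document_task import batch_clean_document_task

document_ids = ["document-id-1", "document-id-2"]  # documents being deleted
file_ids = ["upload-file-id-1"]                    # upload files referenced by those documents
batch_clean_document_task.delay(document_ids, "dataset-id", "text_model", file_ids)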

+ 2
- 3
api/tasks/batch_create_segment_to_index_task.py Show file

@@ -7,13 +7,13 @@ import click
from celery import shared_task # type: ignore
from sqlalchemy import func

from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
from models.dataset import Dataset, Document, DocumentSegment
from services.vector_service import VectorService


@shared_task(queue="dataset")
@@ -96,8 +96,7 @@ def batch_create_segment_to_index_task(
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
# add index to db
indexing_runner = IndexingRunner()
indexing_runner.batch_add_segments(document_segments, dataset)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()

+ 1
- 1
api/tasks/clean_dataset_task.py Show file

@@ -62,7 +62,7 @@ def clean_dataset_task(
if doc_form is None:
raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None)
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)

for document in documents:
db.session.delete(document)

+ 1
- 1
api/tasks/clean_document_task.py Show file

@@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)

+ 1
- 1
api/tasks/clean_notion_document_task.py Show file

@@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]

index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
db.session.delete(segment)

+ 19
- 3
api/tasks/deal_dataset_vector_index_task.py Show file

@@ -4,8 +4,9 @@ import time
import click
from celery import shared_task # type: ignore

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
db.session.commit()

# clean index
index_processor.clean(dataset, None, with_keywords=False)
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

for dataset_document in dataset_documents:
# update from vector index
@@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id,
},
)

if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)

+ 6
- 16
api/tasks/delete_segment_from_index_task.py Show file

@@ -6,48 +6,38 @@ from celery import shared_task # type: ignore

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document


@shared_task(queue="dataset")
def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str):
def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
"""
Async Remove segment from index
:param segment_id:
:param index_node_id:
:param index_node_ids:
:param dataset_id:
:param document_id:

Usage: delete_segment_from_index_task.delay(segment_id)
Usage: delete_segment_from_index_task.delay(index_node_ids, dataset_id, document_id)
"""
logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green"))
logging.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
indexing_cache_key = "segment_{}_delete_indexing".format(segment_id)
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan"))
return

dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan"))
return

if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan"))
return

index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [index_node_id])
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

end_at = time.perf_counter()
logging.info(
click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green")
)
logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
logging.exception("delete segment from index failed")
finally:
redis_client.delete(indexing_cache_key)
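The task now takes a list of index node ids instead of a single segment id, so callers are expected to collect the node ids before deleting the segment rows. A minimal dispatch sketch with placeholder ids:

# Hypothetical dispatch example; ids are placeholders.
from tasks.delete_segment_from_index_task import delete_segment_from_index_task

index_node_ids = ["index-node-id-1", "index-node-id-2"]  # collected from the segments before they are deleted
delete_segment_from_index_task.delay(index_node_ids, "dataset-id", "document-id")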

+ 76
- 0
api/tasks/disable_segments_from_index_task.py Show file

@@ -0,0 +1,76 @@
import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async disable segments from index
:param segment_ids:
:param dataset_id:
:param document_id:

Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()

dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return

dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()

if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)

if not segments:
return

try:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)

end_at = time.perf_counter()
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
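Note that disabling cleans the segments from the vector and keyword indexes with delete_child_chunks=False, so the stored child chunk rows are kept and the segments can be re-indexed later by the enable task. A minimal dispatch sketch with placeholder ids:

# Hypothetical dispatch example; ids are placeholders.
from tasks.disable_segments_from_index_task import disable_segments_from_index_task

segment_ids = ["segment-id-1", "segment-id-2"]
disable_segments_from_index_task.delay(segment_ids, "dataset-id", "document-id")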

+ 1
- 1
api/tasks/document_indexing_sync_task.py Show file

@@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments]

# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
db.session.delete(segment)

+ 1
- 1
api/tasks/document_indexing_update_task.py Show file

@@ -47,7 +47,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments]

# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
db.session.delete(segment)

+ 3
- 3
api/tasks/duplicate_document_indexing_task.py Show file

@@ -51,7 +51,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
return
@@ -73,14 +73,14 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
index_node_ids = [segment.index_node_id for segment in segments]

# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
db.session.delete(segment)
db.session.commit()

document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow()
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
documents.append(document)
db.session.add(document)
db.session.commit()

+ 18
- 1
api/tasks/enable_segment_to_index_task.py Show file

@@ -6,8 +6,9 @@ import click
from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@@ -61,6 +62,22 @@ def enable_segment_to_index_task(segment_id: str):
return

index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
# save vector index
index_processor.load(dataset, [document])


+ 108
- 0
api/tasks/enable_segments_to_index_task.py Show file

@@ -0,0 +1,108 @@
import datetime
import logging
import time

import click
from celery import shared_task

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async enable segments to index
:param segment_ids:
:param dataset_id:
:param document_id:

Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return

dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()

if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()

segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return

try:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)

if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents)

end_at = time.perf_counter()
logging.info(click.style("Segments enabled to index latency: {}".format(end_at - start_at), fg="green"))
except Exception as e:
logging.exception("enable segments to index failed")
# update segment error msg
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)
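The payload built here has the same shape as in add_document_to_index_task, deal_dataset_vector_index_task and enable_segment_to_index_task: one Document per segment, with a ChildDocument per child chunk attached through document.children when the doc form is PARENT_CHILD_INDEX. A minimal sketch of that construction as a standalone helper; the helper name is hypothetical, while Document, ChildDocument and the segment and child chunk fields come from the code above:

# Illustrative helper only; build_document_with_children is a hypothetical name.
from core.rag.models.document import ChildDocument, Document


def build_document_with_children(segment) -> Document:
    """Wrap a DocumentSegment and its child chunks into the payload passed to index_processor.load()."""
    document = Document(
        page_content=segment.content,
        metadata={
            "doc_id": segment.index_node_id,
            "doc_hash": segment.index_node_hash,
            "document_id": segment.document_id,
            "dataset_id": segment.dataset_id,
        },
    )
    if segment.child_chunks:
        document.children = [
            ChildDocument(
                page_content=child_chunk.content,
                metadata={
                    "doc_id": child_chunk.index_node_id,
                    "doc_hash": child_chunk.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            for child_chunk in segment.child_chunks
        ]
    return document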

+ 1
- 1
api/tasks/remove_document_from_index_task.py Show file

@@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str):
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logging.exception(f"clean dataset {dataset.id} from index failed")


+ 7
- 7
api/tasks/retry_document_indexing_task.py Show file

@@ -48,7 +48,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
@@ -69,14 +69,14 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
db.session.delete(segment)
db.session.commit()
for segment in segments:
db.session.delete(segment)
db.session.commit()

document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow()
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()

@@ -86,7 +86,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))

+ 7
- 7
api/tasks/sync_website_document_indexing_task.py Show file

@@ -46,7 +46,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
@@ -65,14 +65,14 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

for segment in segments:
db.session.delete(segment)
db.session.commit()
for segment in segments:
db.session.delete(segment)
db.session.commit()

document.indexing_status = "parsing"
document.processing_started_at = datetime.datetime.utcnow()
document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()

@@ -82,7 +82,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = datetime.datetime.utcnow()
document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
db.session.add(document)
db.session.commit()
logging.info(click.style(str(ex), fg="yellow"))

+ 98
- 0
api/templates/clean_document_job_mail_template-US.html Show file

@@ -0,0 +1,98 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Documents Disabled Notification</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 0;
background-color: #f5f5f5;
}
.email-container {
max-width: 600px;
margin: 20px auto;
background: #ffffff;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
overflow: hidden;
}
.header {
background-color: #eef2fa;
padding: 20px;
text-align: center;
}
.header img {
height: 40px;
}
.content {
padding: 20px;
line-height: 1.6;
color: #333;
}
.content h1 {
font-size: 24px;
color: #222;
}
.content p {
margin: 10px 0;
}
.content ul {
padding-left: 20px;
}
.content ul li {
margin-bottom: 10px;
}
.cta-button {
display: block;
margin: 20px auto;
padding: 10px 20px;
background-color: #4e89f9;
color: #ffffff;
text-align: center;
text-decoration: none;
border-radius: 5px;
width: fit-content;
}
.footer {
text-align: center;
padding: 10px;
font-size: 12px;
color: #777;
background-color: #f9f9f9;
}
</style>
</head>
<body>
<div class="email-container">
<!-- Header -->
<div class="header">
<img src="https://via.placeholder.com/150x40?text=Dify" alt="Dify Logo">
</div>

<!-- Content -->
<div class="content">
<h1>Some Documents in Your Knowledge Base Have Been Disabled</h1>
<p>Dear {{userName}},</p>
<p>
We're sorry for the inconvenience. To ensure optimal performance, documents
that haven’t been updated or accessed in the past 7 days have been disabled in
your knowledge bases:
</p>
<ul>
{{knowledge_details}}
</ul>
<p>You can re-enable them anytime.</p>
<a href="{{url}}" class="cta-button">Re-enable in Dify</a>
</div>

<!-- Footer -->
<div class="footer">
Sincerely,<br>
The Dify Team
</div>
</div>
</body>
</html>
